Skip to content

Commit

Permalink
Unify type bounds: handle type params
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanjermakov committed Aug 18, 2024
1 parent 79d5197 commit f9afd5c
Show file tree
Hide file tree
Showing 8 changed files with 192 additions and 126 deletions.
2 changes: 1 addition & 1 deletion src/ast/type.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { Context } from '../scope'
import { Hole, buildHole } from './match'
import { Identifier, Name, buildIdentifier, buildName } from './operand'

export type Type = Identifier | FnType | Hole
export type Type = Identifier | FnType | Hole | Name

export const buildType = (node: ParseNode, ctx: Context): Type => {
const n = filterNonAstNodes(node)[0]
Expand Down
7 changes: 5 additions & 2 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,13 @@ const phases = [
desugar1,
resolveName,
setTopScopeType,
collectTypeBounds,
unifyTypeBounds
collectTypeBounds
]
phases.forEach(f => eachModule(f, ctx))
const m = ctx.packages.at(-1)!.modules[0]
ctx.moduleStack.push(m)
unifyTypeBounds(m, ctx)
ctx.moduleStack.pop()

reportErrors(ctx)
reportWarnings(ctx)
Expand Down
3 changes: 3 additions & 0 deletions src/phase/sugar.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import { AstNode } from '../ast'
import { Context, idFromString } from '../scope'
import { unitType } from '../typecheck/type'

/**
* Desugar phase that runs before name resolution
* Does:
* - populates type of the `self` param
* - sets fnDef.instance
* - set fnDef.returnType to Unit if not specified
*/
export const desugar1 = (node: AstNode, ctx: Context, parent?: AstNode) => {
switch (node.kind) {
Expand All @@ -29,6 +31,7 @@ export const desugar1 = (node: AstNode, ctx: Context, parent?: AstNode) => {
}
})
}
node.returnType ??= unitType.type
break
}
}
Expand Down
94 changes: 68 additions & 26 deletions src/phase/top-scope-type.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { AstNode } from '../ast'
import { Identifier } from '../ast/operand'
import { Context } from '../scope'
import { makeInferredFromType, makeInferredType, makeTemplateType } from '../typecheck'
import { makeDefType, makeTypeParam } from '../typecheck'
import { unitType } from '../typecheck/type'
import { assert, todo, unreachable } from '../util/todo'

/**
* Set inferred types of topScope nodes
Expand All @@ -14,53 +14,95 @@ export const setTopScopeType = (node: AstNode, ctx: Context) => {
break
}
case 'var-def': {
if (node.pattern.expr.kind !== 'name') break
if (node.pattern.expr.kind !== 'name') return unreachable()
const def = node.pattern.expr
def.type = makeTemplateType(makeInferredFromType(node.varType!))
assert(!!node.varType)
def.type = makeDefType(node.varType!)
break
}
case 'fn-def': {
const generics = node.generics
if (node.instance) {
generics.push(...node.instance.generics)
node.generics.push(...node.instance.generics)
}
node.type = makeTemplateType({
kind: 'inferred-fn',
generics,
params: node.params.map(p => makeInferredFromType(p.paramType!)),
returnType: node.returnType ? makeInferredFromType(node.returnType) : unitType
node.generics.forEach(g => setTopScopeType(g, ctx))
node.params.forEach(p => setTopScopeType(p, ctx))
assert(!!node.returnType)
if (node.returnType) {
setTopScopeType(node.returnType, ctx)
}
node.type = makeDefType({
kind: 'fn-type',
generics: node.generics,
paramTypes: node.params.map(p => p.paramType!),
returnType: node.returnType ? node.returnType : unitType.type
})
break
}
case 'type-def': {
// TODO: generics
const nodeId: Identifier = {
kind: 'identifier',
parseNode: node.name.parseNode,
names: [node.name],
typeArgs: [],
def: node
}
node.type = makeTemplateType(nodeId)
node.type = makeDefType(node.name)
node.variants.forEach(v => {
v.type = makeTemplateType({
kind: 'inferred-fn',
v.fieldDefs.forEach(f => setTopScopeType(f, ctx))
const fnType = {
kind: <const>'fn-type',
generics: node.generics,
params: v.fieldDefs.map(f => makeInferredFromType(f.fieldType)),
returnType: nodeId
})
paramTypes: v.fieldDefs.map(f => f.fieldType),
returnType: node.name
}
setTopScopeType(fnType, ctx)
v.type = makeDefType(fnType)
})
break
}
case 'field-def': {
assert(!!node.fieldType)
setTopScopeType(node.fieldType!, ctx)
node.type = node.fieldType!.type!
setTopScopeType(node.name, ctx)
break
}
case 'trait-def':
case 'impl-def': {
if (node.kind === 'impl-def' && node.forTrait) break
node.generics.forEach(g => setTopScopeType(g, ctx))
node.block.statements.forEach(s => setTopScopeType(s, ctx))
break
}
case 'param': {
assert(!!node.paramType)
setTopScopeType(node.paramType!, ctx)
node.type = node.paramType!.type!
setTopScopeType(node.pattern, ctx)
break
}
case 'pattern': {
if (node.expr.kind !== 'name') {
return todo()
}
setTopScopeType(node.expr, ctx)
break
}
case 'identifier': {
node.typeArgs.forEach(ta => setTopScopeType(ta, ctx))
node.type = node.def?.type ?? { kind: 'error', message: 'no def' }
break
}
case 'name': {
node.type = node.def?.type ?? { kind: 'error', message: 'no def' }
break
}
case 'fn-type': {
node.generics.forEach(pt => setTopScopeType(pt, ctx))
node.paramTypes.forEach(pt => setTopScopeType(pt, ctx))
setTopScopeType(node.returnType, ctx)
node.type = makeDefType(node)
break
}
case 'hole': {
node.type = { kind: 'hole' }
break
}
case 'generic': {
node.type = makeInferredType()
node.type = makeTypeParam(node.name)
break
}
}
Expand Down
48 changes: 20 additions & 28 deletions src/phase/type-bound.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import { AstNode } from '../ast'
import { FnDef } from '../ast/statement'
import { Context, addError } from '../scope'
import { genericError } from '../semantic/error'
import { Context } from '../scope'
import { operatorImplMap } from '../semantic/op'
import {
InferredType,
addBounds,
instantiateTemplateType,
makeInferredFromType,
instantiateDefType,
makeDefType,
makeInferredType,
makeReturnType
} from '../typecheck'
import { boolType, charType, floatType, intType, stringType, unitType } from '../typecheck/type'
import { assert } from '../util/todo'
import { assert, unreachable } from '../util/todo'
import { findById, findParent } from './name-resolve'

/**
Expand Down Expand Up @@ -85,14 +84,7 @@ export const collectTypeBounds = (node: AstNode, ctx: Context, parentBound?: Inf
break
}
case 'param': {
if (!node.paramType) break
collectTypeBounds(node.paramType, ctx)
const pType =
node.paramType.kind === 'identifier' && node.paramType.def
? node.paramType.def.type!
: makeInferredFromType(node.paramType)
addBounds(node.type!, [pType])
collectTypeBounds(node.pattern, ctx, pType)
// TODO
break
}
case 'generic': {
Expand Down Expand Up @@ -126,7 +118,8 @@ export const collectTypeBounds = (node: AstNode, ctx: Context, parentBound?: Inf
case 'identifier':
case 'name': {
if (node.def) {
node.type = instantiateTemplateType(node.def.type!)
assert(!!node.def.type)
node.type = instantiateDefType(node.def.type!)
break
}
break
Expand All @@ -140,7 +133,7 @@ export const collectTypeBounds = (node: AstNode, ctx: Context, parentBound?: Inf
collectTypeBounds(node.operand, ctx)
switch (node.op.kind) {
case 'call-op': {
const fnType = instantiateTemplateType(node.operand.type!)
const fnType = instantiateDefType(node.operand.type!)
node.op.args.forEach(a => collectTypeBounds(a, ctx))
addBounds(fnType, [boundFromCall(node.op.args.map(a => a.type!))])
node.type = makeReturnType(fnType)
Expand Down Expand Up @@ -168,7 +161,7 @@ export const collectTypeBounds = (node: AstNode, ctx: Context, parentBound?: Inf
assert(!!methodId)
const methodDef = findById(methodId!, ctx)
assert(!!methodDef)
const fnType = instantiateTemplateType(methodDef!.type!)
const fnType = instantiateDefType(methodDef!.type!)
addBounds(fnType, [boundFromCall([node.lOperand.type!, node.rOperand.type!])])
node.type = makeReturnType(fnType)
break
Expand All @@ -195,22 +188,21 @@ export const collectTypeBounds = (node: AstNode, ctx: Context, parentBound?: Inf
}
case 'var-def': {
if (node.expr) {
collectTypeBounds(node.expr, ctx, node.varType ? makeInferredFromType(node.varType) : undefined)
collectTypeBounds(node.expr, ctx, node.varType ? makeDefType(node.varType) : undefined)
}
collectTypeBounds(node.pattern, ctx, node.expr?.type)
node.type = instantiateTemplateType(unitType)
node.type = instantiateDefType(unitType)
break
}
case 'fn-def': {
node.generics.forEach(g => collectTypeBounds(g, ctx))
node.params.forEach(p => collectTypeBounds(p, ctx))
if (node.block) {
if (node.type?.kind !== 'template' || node.type.type.kind !== 'inferred-fn') {
addError(ctx, genericError(ctx, node, 'no type'))
if (node.type?.kind !== 'def' || node.type.type.kind !== 'fn-type') {
unreachable()
break
// return unreachable()
}
collectTypeBounds(node.block, ctx, node.type.type.returnType)
collectTypeBounds(node.block, ctx, makeDefType(node.type.type.returnType))
}
break
}
Expand All @@ -223,27 +215,27 @@ export const collectTypeBounds = (node: AstNode, ctx: Context, parentBound?: Inf
}
case 'string-interpolated': {
node.tokens.filter(t => typeof t !== 'string').forEach(t => collectTypeBounds(t, ctx))
node.type = instantiateTemplateType(stringType)
node.type = instantiateDefType(stringType)
break
}
case 'string-literal': {
node.type = instantiateTemplateType(stringType)
node.type = instantiateDefType(stringType)
break
}
case 'char-literal': {
node.type = instantiateTemplateType(charType)
node.type = instantiateDefType(charType)
break
}
case 'int-literal': {
node.type = instantiateTemplateType(intType)
node.type = instantiateDefType(intType)
break
}
case 'float-literal': {
node.type = instantiateTemplateType(floatType)
node.type = instantiateDefType(floatType)
break
}
case 'bool-literal': {
node.type = instantiateTemplateType(boolType)
node.type = instantiateDefType(boolType)
break
}
case 'method-call-op': {
Expand Down
Loading

0 comments on commit f9afd5c

Please sign in to comment.