Skip to content

Commit

Permalink
Report line numbers on type errors.
Browse files Browse the repository at this point in the history
  • Loading branch information
cpressey committed Feb 24, 2022
1 parent 204b6db commit 7179e74
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 43 deletions.
17 changes: 17 additions & 0 deletions eg/typecase-error.castile
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
struct person { name: string };
fun foo(a, b: integer|string) {
r = a;
typecase b is integer {
r = r + b;
};
typecase b is person {
r = r + len(b);
};
r
}
main = fun() {
a = 0;
a = foo(a, 333 as integer|string);
a = foo(a, "hiya" as integer|string);
a /* should output 337 */
}
89 changes: 46 additions & 43 deletions src/castile/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@


class CastileTypeError(ValueError):
pass
def __init__(self, ast, message, *args, **kwargs):
message = 'line {}: {}'.format(ast.line, message)
super(CastileTypeError, self).__init__(message, *args, **kwargs)


class StructDefinition(object):
Expand Down Expand Up @@ -49,10 +51,10 @@ def set(self, name, type):
print('%s: %s' % (name, type))
return type

def assert_eq(self, t1, t2):
def assert_eq(self, ast, t1, t2):
if t1 == t2:
return
raise CastileTypeError("type mismatch: %s != %s" % (t1, t2))
raise CastileTypeError(ast, "type mismatch: %s != %s" % (t1, t2))

def collect_structs(self, ast):
for child in ast.children:
Expand All @@ -62,7 +64,7 @@ def collect_structs(self, ast):
def collect_struct(self, ast):
name = ast.value
if name in self.structs:
raise CastileTypeError('duplicate struct %s' % name)
raise CastileTypeError(ast, 'duplicate struct %s' % name)
struct_fields = {}
type_exprs = []
i = 0
Expand All @@ -74,7 +76,7 @@ def collect_struct(self, ast):
assert child.tag == 'FieldDefn', child.tag
field_name = child.value
if field_name in struct_fields:
raise CastileTypeError('already-defined field %s' % field_name)
raise CastileTypeError(child, 'already-defined field %s' % field_name)
struct_fields[field_name] = i
i += 1
type_exprs.append(self.type_of(child.children[0]))
Expand All @@ -83,7 +85,7 @@ def collect_struct(self, ast):
def resolve_structs(self, ast):
if isinstance(ast.type, Struct):
if ast.type.name not in self.structs:
raise CastileTypeError('undefined struct %s' % ast.type.name)
raise CastileTypeError(ast, 'undefined struct %s' % ast.type.name)
ast.type.defn = self.structs[ast.type.name]
for child in ast.children:
self.resolve_structs(child)
Expand All @@ -92,27 +94,27 @@ def resolve_structs(self, ast):
def type_of(self, ast):
if ast.tag == 'Op':
if ast.value in ('and', 'or'):
self.assert_eq(self.type_of(ast.children[0]), Boolean())
self.assert_eq(self.type_of(ast.children[1]), Boolean())
self.assert_eq(ast, self.type_of(ast.children[0]), Boolean())
self.assert_eq(ast, self.type_of(ast.children[1]), Boolean())
ast.type = Boolean()
elif ast.value in ('+', '-', '*', '/'):
type1 = self.type_of(ast.children[0])
type2 = self.type_of(ast.children[1])
self.assert_eq(type1, type2)
self.assert_eq(type1, Integer())
self.assert_eq(ast, type1, type2)
self.assert_eq(ast, type1, Integer())
ast.type = Integer()
elif ast.value in ('==', '!=', '>', '>=', '<', '<='):
type1 = self.type_of(ast.children[0])
type2 = self.type_of(ast.children[1])
self.assert_eq(type1, type2)
self.assert_eq(ast, type1, type2)
if isinstance(type1, Struct):
raise CastileTypeError("structs cannot be compared")
raise CastileTypeError(ast, "structs cannot be compared")
if isinstance(type1, Union) and type1.contains_instance_of(Struct):
raise CastileTypeError("unions containing structs cannot be compared")
raise CastileTypeError(ast, "unions containing structs cannot be compared")
ast.type = Boolean()
elif ast.tag == 'Not':
type1 = self.type_of(ast.children[0])
self.assert_eq(type1, Boolean())
self.assert_eq(ast, type1, Boolean())
ast.type = Boolean()
elif ast.tag == 'IntLit':
ast.type = Integer()
Expand All @@ -122,12 +124,13 @@ def type_of(self, ast):
ast.type = Boolean()
elif ast.tag == 'FunLit':
save_context = self.context
self.context = ScopedContext({}, self.toplevel_context,
level='argument')
self.context = ScopedContext(
{}, self.toplevel_context, level='argument'
)
self.return_type = None
arg_types = self.type_of(ast.children[0]) # args
t = self.type_of(ast.children[1]) # body
self.assert_eq(t, Void())
self.assert_eq(ast, t, Void())
self.context = save_context
return_type = self.return_type
self.return_type = None
Expand All @@ -152,7 +155,7 @@ def type_of(self, ast):
elif ast.tag == 'Body':
self.context = ScopedContext({}, self.context,
level='local')
self.assert_eq(self.type_of(ast.children[1]), Void())
self.assert_eq(ast, self.type_of(ast.children[1]), Void())
self.context = self.context.parent
ast.type = Void()
elif ast.tag == 'FunType':
Expand All @@ -165,7 +168,7 @@ def type_of(self, ast):
for c in ast.children:
type_ = self.type_of(c)
if type_ in types:
raise CastileTypeError("bad union type")
raise CastileTypeError(c, "bad union type")
types.append(type_)
ast.type = Union(types)
elif ast.tag == 'StructType':
Expand All @@ -180,18 +183,18 @@ def type_of(self, ast):
assert isinstance(t1, Function), \
'%r is not a function' % t1
if len(t1.arg_types) != len(ast.children) - 1:
raise CastileTypeError("argument mismatch")
raise CastileTypeError(ast, "argument mismatch")
i = 0
for child in ast.children[1:]:
self.assert_eq(self.type_of(child), t1.arg_types[i])
self.assert_eq(ast, self.type_of(child), t1.arg_types[i])
i += 1
ast.type = t1.return_type
elif ast.tag == 'Return':
t1 = self.type_of(ast.children[0])
if self.return_type is None:
self.return_type = t1
else:
self.assert_eq(t1, self.return_type)
self.assert_eq(ast, t1, self.return_type)
ast.type = Void()
elif ast.tag == 'Break':
ast.type = Void()
Expand All @@ -204,7 +207,7 @@ def type_of(self, ast):
if len(ast.children) == 3:
# TODO useless! is void.
t3 = self.type_of(ast.children[2])
self.assert_eq(t2, t3)
self.assert_eq(ast, t2, t3)
ast.type = t2
else:
ast.type = Void()
Expand All @@ -213,46 +216,46 @@ def type_of(self, ast):
within_control = self.within_control
self.within_control = True
t1 = self.type_of(ast.children[0])
self.assert_eq(t1, Boolean())
self.assert_eq(ast, t1, Boolean())
t2 = self.type_of(ast.children[1])
ast.type = Void()
self.within_control = within_control
elif ast.tag == 'Block':
for child in ast.children:
self.assert_eq(self.type_of(child), Void())
self.assert_eq(ast, self.type_of(child), Void())
ast.type = Void()
elif ast.tag == 'Assignment':
t2 = self.type_of(ast.children[1])
t1 = None
name = ast.children[0].value
if ast.aux == 'defining instance':
if self.within_control:
raise CastileTypeError('definition of %s within control block' % name)
raise CastileTypeError(ast, 'definition of %s within control block' % name)
if name in self.context:
raise CastileTypeError('definition of %s shadows previous' % name)
raise CastileTypeError(ast, 'definition of %s shadows previous' % name)
self.set(name, t2)
t1 = t2
else:
if name not in self.context:
raise CastileTypeError('variable %s used before definition' % name)
raise CastileTypeError(ast, 'variable %s used before definition' % name)
t1 = self.type_of(ast.children[0])
self.assert_eq(t1, t2)
self.assert_eq(ast, t1, t2)
# not quite useless now (typecase still likes this)
if self.context.level(ast.children[0].value) != 'local':
raise CastileTypeError('cannot assign to non-local')
raise CastileTypeError(ast, 'cannot assign to non-local')
ast.type = Void()
elif ast.tag == 'Make':
t = self.type_of(ast.children[0])
if t.name not in self.structs:
raise CastileTypeError("undefined struct %s" % t.name)
raise CastileTypeError(ast, "undefined struct %s" % t.name)
struct_defn = self.structs[t.name]
if struct_defn.scope_idents is not None:
if self.current_defn not in struct_defn.scope_idents:
raise CastileTypeError("inaccessible struct %s for make: %s not in %s" %
raise CastileTypeError(ast, "inaccessible struct %s for make: %s not in %s" %
(t.name, self.current_defn, struct_defn.scope_idents)
)
if len(struct_defn.content_types) != len(ast.children) - 1:
raise CastileTypeError("argument mismatch; expected {}, got {} in {}".format(
raise CastileTypeError(ast, "argument mismatch; expected {}, got {} in {}".format(
len(struct_defn.content_types), len(ast.children) - 1, ast
))
i = 0
Expand All @@ -261,7 +264,7 @@ def type_of(self, ast):
t1 = self.type_of(defn)
pos = struct_defn.field_names[name]
defn.aux = pos
self.assert_eq(t1, struct_defn.content_types[pos])
self.assert_eq(ast, t1, struct_defn.content_types[pos])
i += 1
ast.type = t
elif ast.tag == 'FieldInit':
Expand All @@ -271,13 +274,13 @@ def type_of(self, ast):
struct_defn = self.structs[t.name]
if struct_defn.scope_idents is not None:
if self.current_defn not in struct_defn.scope_idents:
raise CastileTypeError("inaccessible struct %s for access: %s not in %s" %
raise CastileTypeError(ast, "inaccessible struct %s for access: %s not in %s" %
(t.name, self.current_defn, struct_defn.scope_idents)
)
field_name = ast.value
struct_fields = struct_defn.field_names
if field_name not in struct_fields:
raise CastileTypeError("undefined field")
raise CastileTypeError(ast, "undefined field")
index = struct_fields[field_name]
# we make this value available to compiler backends
ast.aux = index
Expand All @@ -287,9 +290,9 @@ def type_of(self, ast):
t1 = self.type_of(ast.children[0])
t2 = self.type_of(ast.children[1])
if not isinstance(t1, Union):
raise CastileTypeError('bad typecase, %s not a union' % t1)
raise CastileTypeError(ast, 'bad typecase, %s not a union' % t1)
if not t1.contains(t2):
raise CastileTypeError('bad typecase, %s not in %s' % (t2, t1))
raise CastileTypeError(ast, 'bad typecase, %s not in %s' % (t2, t1))
# typecheck t3 with variable in children[0] having type t2
assert ast.children[0].tag == 'VarRef'
within_control = self.within_control
Expand All @@ -301,23 +304,23 @@ def type_of(self, ast):
self.within_control = within_control
elif ast.tag == 'Program':
for defn in ast.children:
self.assert_eq(self.type_of(defn), Void())
self.assert_eq(ast, self.type_of(defn), Void())
ast.type = Void()
self.resolve_structs(ast)
elif ast.tag == 'Defn':
self.current_defn = ast.value
t = self.type_of(ast.children[0])
self.current_defn = None
if ast.value in self.forwards:
self.assert_eq(self.forwards[ast.value], t)
self.assert_eq(ast, self.forwards[ast.value], t)
del self.forwards[ast.value]
else:
self.set(ast.value, t)
if ast.value == 'main':
# any return type is fine, for now, so,
# we compare it against itself
rt = t.return_type
self.assert_eq(t, Function([], rt))
self.assert_eq(ast, t, Function([], rt))
ast.type = Void()
elif ast.tag == 'Forward':
t = self.type_of(ast.children[0])
Expand All @@ -330,10 +333,10 @@ def type_of(self, ast):
value_t = self.type_of(ast.children[0])
union_t = self.type_of(ast.children[1])
if not isinstance(union_t, Union):
raise CastileTypeError('bad cast, not a union: %s' % union_t)
raise CastileTypeError(ast, 'bad cast, not a union: %s' % union_t)
if not union_t.contains(value_t):
raise CastileTypeError(
'bad cast, %s does not include %s' % (union_t, value_t)
ast, 'bad cast, %s does not include %s' % (union_t, value_t)
)
ast.type = union_t
else:
Expand Down

0 comments on commit 7179e74

Please sign in to comment.