Skip to content

Commit b284f72

Browse files
committed
✨ Implement computation of resulting function signatures
1 parent 6f5da11 commit b284f72

File tree

1 file changed

+205
-8
lines changed

1 file changed

+205
-8
lines changed

combinator.py

Lines changed: 205 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import itertools
3030
import os
3131
import sys
32-
from typing import cast, ClassVar, Dict, List, override, Optional, Set, Tuple
32+
from typing import cast, ClassVar, Dict, List, override, Optional, Self, Set, Tuple
3333

3434

3535

@@ -215,6 +215,10 @@ def collect_free_vars(self) -> Set[str]:
215215
"""Return a set of all the variables that are free for the entire expression."""
216216
return set()
217217

218+
def get_apps(self) -> List[Obj]:
219+
"""Return a list of all the applications in the expression."""
220+
return []
221+
218222

219223
class Var(Obj):
220224
"""Object to represent a variable in the lambda expression graph. This implements
@@ -224,6 +228,7 @@ class Var(Obj):
224228
def __init__(self, freename: Optional[str] = None):
225229
self.id = Var.varcnt
226230
self.freename = freename
231+
self.identical_to: Optional[Var] = None
227232
Var.varcnt += 1
228233

229234
@override
@@ -234,7 +239,7 @@ def is_free_in_context(self, v: Var) -> bool:
234239

235240
@override
236241
def __str__(self):
237-
return f'{{var {self.id}}}'
242+
return f'{{var {self.id}{self.identical_to.id}}}' if self.identical_to else f'{{var {self.id}}}'
238243

239244
@override
240245
def fmt(self, varmap: Naming, highlight: bool) -> str:
@@ -312,6 +317,10 @@ def fmt(self, varmap: Naming, highlight: bool) -> str:
312317
combres += ' ' + ' '.join([a.fmt(varmap, highlight) for a in self.arguments])
313318
return combres
314319

320+
@override
321+
def get_apps(self) -> List[Obj]:
322+
return list(itertools.chain.from_iterable(a.get_apps() for a in self.arguments))
323+
315324

316325
class Application(Obj):
317326
"""Object to represent the application (call) to a function in the lambda
@@ -353,6 +362,10 @@ def rmatch(self, other: Obj, var_map: Dict[Var, Obj]) -> bool:
353362
def collect_free_vars(self) -> Set[str]:
354363
return set().union(*[e.collect_free_vars() for e in self.code])
355364

365+
@override
366+
def get_apps(self) -> List[Obj]:
367+
return [self] + list(itertools.chain.from_iterable(a.get_apps() for a in self.code))
368+
356369
def beta(self) -> Obj:
357370
"""Perform beta reduction on the given application. This is called on a freshly
358371
created object but the reduction cannot be performed in the constructor because
@@ -442,6 +455,10 @@ def rmatch(self, other: Obj, var_map: Dict[Var, Obj]) -> bool:
442455
def collect_free_vars(self) -> Set[str]:
443456
return self.code.collect_free_vars()
444457

458+
@override
459+
def get_apps(self) -> List[Obj]:
460+
return self.code.get_apps()
461+
445462

446463
def parse_lambda(s: str, ctx: Dict[str, Var]) -> Tuple[Obj, str]:
447464
"""Parse the representation of a lambda definition. Return the graph
@@ -595,17 +612,170 @@ def to_string(expr: Obj, highlight: bool = False) -> str:
595612
return remove_braces(expr.recombine().fmt(Naming(expr.collect_free_vars()), highlight)).rstrip()
596613

597614

615+
class Vargen:
616+
"""Class to generate unique, consecutive variable names."""
617+
def __init__(self):
618+
self.typeidx: int = 0
619+
620+
def name(self):
621+
res = VARIABLE_NAMES[self.typeidx]
622+
self.typeidx += 1
623+
return res
624+
625+
626+
class Type:
627+
"""Simple type object used in type signatures."""
628+
def __init__(self):
629+
self.name = None
630+
631+
def __repr__(self):
632+
return f'Type({self.name})'
633+
634+
def __str__(self):
635+
assert self.name is not None
636+
return self.name
637+
638+
def finalize(self, gen: Vargen) -> Self:
639+
"""Create variable name."""
640+
if self.name is None:
641+
self.name = gen.name()
642+
return self
643+
644+
645+
class TypeFunc(Type):
646+
"""Type object for functions."""
647+
def __init__(self, app: Application):
648+
self.args = app
649+
self.ret = None
650+
651+
def __repr__(self):
652+
return f'TypeFunc({self.args})'
653+
654+
655+
def determine_types(obj: Obj, types: Dict[Obj, List[Type]]) -> Dict[Obj, List[Type]]:
656+
"""Traverse the expression object and extract type information and handles for variables and
657+
function invocations."""
658+
match obj:
659+
case Var():
660+
if obj not in types:
661+
types[cast(Var, obj)] = [Type()]
662+
case Application():
663+
assert obj not in types
664+
for a in obj.code:
665+
types = determine_types(a, types)
666+
tf = TypeFunc(obj)
667+
if obj.code[0] in types:
668+
# This can only happen for variables
669+
assert isinstance(obj.code[0], Var)
670+
if not isinstance(types[obj.code[0]][0], TypeFunc):
671+
types[obj.code[0]] = [tf]
672+
else:
673+
types[obj.code[0]].append(tf)
674+
else:
675+
types[obj.code[0]] = [tf]
676+
types[obj] = [Type()]
677+
case Lambda():
678+
types = determine_types(obj.code, dict())
679+
case Combinator():
680+
raise RuntimeError("Combinator cannot be handled in collect_types")
681+
case _:
682+
raise NotImplementedError(f"Unexpected type #1 {type(obj)}")
683+
return types
684+
685+
686+
def notation(obj: Obj, types: Dict[Obj, Type], gen: Vargen) -> str:
687+
"""Create a Haskell-like notation for the signature of the expression."""
688+
match obj:
689+
case Var():
690+
assert obj in types
691+
return str(types[obj].finalize(gen))
692+
case Lambda():
693+
l = []
694+
if isinstance(obj.code, Var):
695+
for p in obj.params:
696+
if p in types:
697+
l.append(str(types[p].finalize(gen)))
698+
else:
699+
l.append(str(Type().finalize(gen)))
700+
else:
701+
assert isinstance(obj.code, Application)
702+
for p in obj.params:
703+
assert p in types
704+
match types[p]:
705+
case TypeFunc():
706+
args = cast(TypeFunc, types[p]).args
707+
s = '('
708+
for c in args.code[1:]:
709+
s += f'{str(types[c].finalize(gen))} → '
710+
ttype = types[args].finalize(gen)
711+
l.append(f'{s}{str(ttype)})')
712+
case Type():
713+
rp = p.identical_to or p
714+
l.append(f'{str(types[rp].finalize(gen))}')
715+
case _:
716+
raise NotImplementedError(f"Unexpected type #3 {type(types[p])}")
717+
assert isinstance(types[obj.code], Type)
718+
l.append(str(types[obj.code].finalize(gen)))
719+
return ' → '.join(l)
720+
case _:
721+
raise NotImplementedError(f"Unexpected type #2 {type(obj)}")
722+
723+
724+
def mark_identical(l: Obj, r: Obj):
725+
"""If two objects are recognized after their creation to have the same type, mark them as identical."""
726+
assert type(l) == type(r)
727+
match l:
728+
case Var():
729+
cast(Var, r).identical_to = l
730+
case _:
731+
raise NotImplementedError(f"Unexpected type #3 {type(l)}")
732+
733+
734+
def to_typesig(expr: Obj, highlight: bool = False) -> str:
735+
"""Return a string representation for type signature of the expression."""
736+
gen = Vargen()
737+
if not isinstance(expr, Lambda):
738+
assert isinstance(expr, Var) or isinstance(expr, Application)
739+
t = Type()
740+
t.finalize(gen)
741+
return str(t)
742+
743+
types = determine_types(expr, dict())
744+
745+
stypes = {}
746+
rtypes = {}
747+
for k, t in types.items():
748+
if len(t) != 1:
749+
assert(all(isinstance(tt, TypeFunc) for tt in t))
750+
for tt in t[1:]:
751+
# We cannot modify types while iterating over it and we cannot add the corrected
752+
# value to stypes because the iteration order might cause the value to be overwritten.
753+
rtypes[cast(TypeFunc, tt).args] = types[cast(TypeFunc, t[0]).args][0]
754+
assert cast(TypeFunc, t[0]).args.code[0] == cast(TypeFunc, tt).args.code[0]
755+
for e in zip(cast(TypeFunc, t[0]).args.code[1:], cast(TypeFunc, tt).args.code[1:]):
756+
mark_identical(e[0], e[1])
757+
stypes[k] = t[0]
758+
759+
for e in rtypes:
760+
stypes[e] = rtypes[e]
761+
762+
return notation(expr, stypes, gen)
763+
764+
598765
def handle(a: str, echo: bool) -> int:
599766
"""Parse given string, simplify, and print the lambda expression."""
600767
ec = 0
601768
input_prompt = f'{COLORS["input_prompt"]}»{COLORS["off"]} '
602769
output_prompt = f'{COLORS["output_prompt"]}{COLORS["off"]} ' if IS_TERMINAL else ''
770+
typesig_prompt = f'{COLORS["output_prompt"]}🖊{COLORS["off"]} ' if IS_TERMINAL else ''
603771
separator_len = os.get_terminal_size()[0] if IS_TERMINAL else 72
604772

605773
if echo and IS_TERMINAL:
606774
print(f'{input_prompt}{a}')
607775
try:
608-
print(f'{output_prompt}{to_string(from_string(a), IS_TERMINAL)}')
776+
expr = from_string(a)
777+
print(f'{output_prompt}{to_string(expr, IS_TERMINAL)}')
778+
print(f'{typesig_prompt}{to_typesig(expr, IS_TERMINAL)}')
609779
except SyntaxError as e:
610780
print(f'eval("{a}") failed: {e.args[0]}')
611781
ec = 1
@@ -631,7 +801,7 @@ def repl() -> int:
631801

632802
def check() -> int:
633803
"""Sanity checks. Return error code that is used as the exit code of the process."""
634-
checks = [
804+
combinator_checks = [
635805
('S K K', 'I'),
636806
('K I', 'π'),
637807
('K (S K K)', 'π'),
@@ -700,17 +870,44 @@ def check() -> int:
700870
('B (a b) c', 'D a b c'),
701871
]
702872
ec = 0
703-
for testinput, expected in [(key, key) for key in KNOWN_COMBINATORS] + checks:
873+
print('Combinator checks')
874+
for testinput, expected in [(key, key) for key in KNOWN_COMBINATORS] + combinator_checks:
704875
resexpr = from_string(testinput)
705876
res = to_string(resexpr)
706877
if res != expected:
707878
if expected in KNOWN_COMBINATORS:
708-
print(f'❌ {testinput} {res} {resexpr} but {expected} {from_string(expected)} = {KNOWN_COMBINATORS[expected]} expected')
879+
print(f'❌ {testinput} {res} {resexpr} but {expected} {from_string(expected)} = {KNOWN_COMBINATORS[expected]} expected')
709880
else:
710-
print(f'❌ {testinput}{res} {resexpr} but {expected} {from_string(expected)} expected')
881+
print(f'❌ {testinput}{res} {resexpr} but {expected} {from_string(expected)} expected')
882+
ec = 1
883+
else:
884+
print(f'✅ {testinput}{res}')
885+
886+
signature_checks = [
887+
('B', '(a → b) → (c → a) → c → b'),
888+
('C', '(a → b → c) → b → a → c'),
889+
('C*', '(a → b → c → d) → a → c → b → d'),
890+
('I', 'a → a'),
891+
('I*', '(a → b) → a → b'),
892+
('K', 'a → b → a'),
893+
('Ψ', '(a → a → b) → (c → a) → c → c → b'),
894+
('Q', '(a → b) → (b → c) → a → c'),
895+
('R', 'a → (b → a → c) → b → c'),
896+
('R*', '(a → b → c → d) → c → a → b → d'),
897+
('S', '(a → b → c) → (a → b) → a → c'),
898+
('Φ', '(a → b → c) → (d → a) → (d → b) → d → c'),
899+
('T', 'a → (a → b) → b'),
900+
('W', '(a → a → b) → a → b'),
901+
]
902+
print('\nSignature checks')
903+
for testinput, expected in signature_checks:
904+
resexpr = from_string(testinput)
905+
res = to_typesig(resexpr)
906+
if res != expected:
907+
print(f'❌ {testinput}{res} but {expected} expected')
711908
ec = 1
712909
else:
713-
print(f'✅ {testinput} {res}')
910+
print(f'✅ {testinput} {res}')
714911
return ec
715912

716913

0 commit comments

Comments
 (0)