29
29
import itertools
30
30
import os
31
31
import sys
32
- from typing import cast , ClassVar , Dict , List , override , Optional , Set , Tuple
32
+ from typing import cast , ClassVar , Dict , List , override , Optional , Self , Set , Tuple
33
33
34
34
35
35
@@ -215,6 +215,10 @@ def collect_free_vars(self) -> Set[str]:
215
215
"""Return a set of all the variables that are free for the entire expression."""
216
216
return set ()
217
217
218
+ def get_apps (self ) -> List [Obj ]:
219
+ """Return a list of all the applications in the expression."""
220
+ return []
221
+
218
222
219
223
class Var (Obj ):
220
224
"""Object to represent a variable in the lambda expression graph. This implements
@@ -224,6 +228,7 @@ class Var(Obj):
224
228
def __init__ (self , freename : Optional [str ] = None ):
225
229
self .id = Var .varcnt
226
230
self .freename = freename
231
+ self .identical_to : Optional [Var ] = None
227
232
Var .varcnt += 1
228
233
229
234
@override
@@ -234,7 +239,7 @@ def is_free_in_context(self, v: Var) -> bool:
234
239
235
240
@override
236
241
def __str__ (self ):
237
- return f'{{var { self .id } }}'
242
+ return f'{{var { self . id } → { self . identical_to . id } }}' if self . identical_to else f'{{var { self .id } }}'
238
243
239
244
@override
240
245
def fmt (self , varmap : Naming , highlight : bool ) -> str :
@@ -312,6 +317,10 @@ def fmt(self, varmap: Naming, highlight: bool) -> str:
312
317
combres += ' ' + ' ' .join ([a .fmt (varmap , highlight ) for a in self .arguments ])
313
318
return combres
314
319
320
+ @override
321
+ def get_apps (self ) -> List [Obj ]:
322
+ return list (itertools .chain .from_iterable (a .get_apps () for a in self .arguments ))
323
+
315
324
316
325
class Application (Obj ):
317
326
"""Object to represent the application (call) to a function in the lambda
@@ -353,6 +362,10 @@ def rmatch(self, other: Obj, var_map: Dict[Var, Obj]) -> bool:
353
362
def collect_free_vars (self ) -> Set [str ]:
354
363
return set ().union (* [e .collect_free_vars () for e in self .code ])
355
364
365
+ @override
366
+ def get_apps (self ) -> List [Obj ]:
367
+ return [self ] + list (itertools .chain .from_iterable (a .get_apps () for a in self .code ))
368
+
356
369
def beta (self ) -> Obj :
357
370
"""Perform beta reduction on the given application. This is called on a freshly
358
371
created object but the reduction cannot be performed in the constructor because
@@ -442,6 +455,10 @@ def rmatch(self, other: Obj, var_map: Dict[Var, Obj]) -> bool:
442
455
def collect_free_vars (self ) -> Set [str ]:
443
456
return self .code .collect_free_vars ()
444
457
458
+ @override
459
+ def get_apps (self ) -> List [Obj ]:
460
+ return self .code .get_apps ()
461
+
445
462
446
463
def parse_lambda (s : str , ctx : Dict [str , Var ]) -> Tuple [Obj , str ]:
447
464
"""Parse the representation of a lambda definition. Return the graph
@@ -595,17 +612,170 @@ def to_string(expr: Obj, highlight: bool = False) -> str:
595
612
return remove_braces (expr .recombine ().fmt (Naming (expr .collect_free_vars ()), highlight )).rstrip ()
596
613
597
614
615
+ class Vargen :
616
+ """Class to generate unique, consecutive variable names."""
617
+ def __init__ (self ):
618
+ self .typeidx : int = 0
619
+
620
+ def name (self ):
621
+ res = VARIABLE_NAMES [self .typeidx ]
622
+ self .typeidx += 1
623
+ return res
624
+
625
+
626
+ class Type :
627
+ """Simple type object used in type signatures."""
628
+ def __init__ (self ):
629
+ self .name = None
630
+
631
+ def __repr__ (self ):
632
+ return f'Type({ self .name } )'
633
+
634
+ def __str__ (self ):
635
+ assert self .name is not None
636
+ return self .name
637
+
638
+ def finalize (self , gen : Vargen ) -> Self :
639
+ """Create variable name."""
640
+ if self .name is None :
641
+ self .name = gen .name ()
642
+ return self
643
+
644
+
645
+ class TypeFunc (Type ):
646
+ """Type object for functions."""
647
+ def __init__ (self , app : Application ):
648
+ self .args = app
649
+ self .ret = None
650
+
651
+ def __repr__ (self ):
652
+ return f'TypeFunc({ self .args } )'
653
+
654
+
655
+ def determine_types (obj : Obj , types : Dict [Obj , List [Type ]]) -> Dict [Obj , List [Type ]]:
656
+ """Traverse the expression object and extract type information and handles for variables and
657
+ function invocations."""
658
+ match obj :
659
+ case Var ():
660
+ if obj not in types :
661
+ types [cast (Var , obj )] = [Type ()]
662
+ case Application ():
663
+ assert obj not in types
664
+ for a in obj .code :
665
+ types = determine_types (a , types )
666
+ tf = TypeFunc (obj )
667
+ if obj .code [0 ] in types :
668
+ # This can only happen for variables
669
+ assert isinstance (obj .code [0 ], Var )
670
+ if not isinstance (types [obj .code [0 ]][0 ], TypeFunc ):
671
+ types [obj .code [0 ]] = [tf ]
672
+ else :
673
+ types [obj .code [0 ]].append (tf )
674
+ else :
675
+ types [obj .code [0 ]] = [tf ]
676
+ types [obj ] = [Type ()]
677
+ case Lambda ():
678
+ types = determine_types (obj .code , dict ())
679
+ case Combinator ():
680
+ raise RuntimeError ("Combinator cannot be handled in collect_types" )
681
+ case _:
682
+ raise NotImplementedError (f"Unexpected type #1 { type (obj )} " )
683
+ return types
684
+
685
+
686
+ def notation (obj : Obj , types : Dict [Obj , Type ], gen : Vargen ) -> str :
687
+ """Create a Haskell-like notation for the signature of the expression."""
688
+ match obj :
689
+ case Var ():
690
+ assert obj in types
691
+ return str (types [obj ].finalize (gen ))
692
+ case Lambda ():
693
+ l = []
694
+ if isinstance (obj .code , Var ):
695
+ for p in obj .params :
696
+ if p in types :
697
+ l .append (str (types [p ].finalize (gen )))
698
+ else :
699
+ l .append (str (Type ().finalize (gen )))
700
+ else :
701
+ assert isinstance (obj .code , Application )
702
+ for p in obj .params :
703
+ assert p in types
704
+ match types [p ]:
705
+ case TypeFunc ():
706
+ args = cast (TypeFunc , types [p ]).args
707
+ s = '('
708
+ for c in args .code [1 :]:
709
+ s += f'{ str (types [c ].finalize (gen ))} → '
710
+ ttype = types [args ].finalize (gen )
711
+ l .append (f'{ s } { str (ttype )} )' )
712
+ case Type ():
713
+ rp = p .identical_to or p
714
+ l .append (f'{ str (types [rp ].finalize (gen ))} ' )
715
+ case _:
716
+ raise NotImplementedError (f"Unexpected type #3 { type (types [p ])} " )
717
+ assert isinstance (types [obj .code ], Type )
718
+ l .append (str (types [obj .code ].finalize (gen )))
719
+ return ' → ' .join (l )
720
+ case _:
721
+ raise NotImplementedError (f"Unexpected type #2 { type (obj )} " )
722
+
723
+
724
+ def mark_identical (l : Obj , r : Obj ):
725
+ """If two objects are recognized after their creation to have the same type, mark them as identical."""
726
+ assert type (l ) == type (r )
727
+ match l :
728
+ case Var ():
729
+ cast (Var , r ).identical_to = l
730
+ case _:
731
+ raise NotImplementedError (f"Unexpected type #3 { type (l )} " )
732
+
733
+
734
+ def to_typesig (expr : Obj , highlight : bool = False ) -> str :
735
+ """Return a string representation for type signature of the expression."""
736
+ gen = Vargen ()
737
+ if not isinstance (expr , Lambda ):
738
+ assert isinstance (expr , Var ) or isinstance (expr , Application )
739
+ t = Type ()
740
+ t .finalize (gen )
741
+ return str (t )
742
+
743
+ types = determine_types (expr , dict ())
744
+
745
+ stypes = {}
746
+ rtypes = {}
747
+ for k , t in types .items ():
748
+ if len (t ) != 1 :
749
+ assert (all (isinstance (tt , TypeFunc ) for tt in t ))
750
+ for tt in t [1 :]:
751
+ # We cannot modify types while iterating over it and we cannot add the corrected
752
+ # value to stypes because the iteration order might cause the value to be overwritten.
753
+ rtypes [cast (TypeFunc , tt ).args ] = types [cast (TypeFunc , t [0 ]).args ][0 ]
754
+ assert cast (TypeFunc , t [0 ]).args .code [0 ] == cast (TypeFunc , tt ).args .code [0 ]
755
+ for e in zip (cast (TypeFunc , t [0 ]).args .code [1 :], cast (TypeFunc , tt ).args .code [1 :]):
756
+ mark_identical (e [0 ], e [1 ])
757
+ stypes [k ] = t [0 ]
758
+
759
+ for e in rtypes :
760
+ stypes [e ] = rtypes [e ]
761
+
762
+ return notation (expr , stypes , gen )
763
+
764
+
598
765
def handle (a : str , echo : bool ) -> int :
599
766
"""Parse given string, simplify, and print the lambda expression."""
600
767
ec = 0
601
768
input_prompt = f'{ COLORS ["input_prompt" ]} »{ COLORS ["off" ]} '
602
769
output_prompt = f'{ COLORS ["output_prompt" ]} ⇒{ COLORS ["off" ]} ' if IS_TERMINAL else ''
770
+ typesig_prompt = f'{ COLORS ["output_prompt" ]} 🖊{ COLORS ["off" ]} ' if IS_TERMINAL else ''
603
771
separator_len = os .get_terminal_size ()[0 ] if IS_TERMINAL else 72
604
772
605
773
if echo and IS_TERMINAL :
606
774
print (f'{ input_prompt } { a } ' )
607
775
try :
608
- print (f'{ output_prompt } { to_string (from_string (a ), IS_TERMINAL )} ' )
776
+ expr = from_string (a )
777
+ print (f'{ output_prompt } { to_string (expr , IS_TERMINAL )} ' )
778
+ print (f'{ typesig_prompt } { to_typesig (expr , IS_TERMINAL )} ' )
609
779
except SyntaxError as e :
610
780
print (f'eval("{ a } ") failed: { e .args [0 ]} ' )
611
781
ec = 1
@@ -631,7 +801,7 @@ def repl() -> int:
631
801
632
802
def check () -> int :
633
803
"""Sanity checks. Return error code that is used as the exit code of the process."""
634
- checks = [
804
+ combinator_checks = [
635
805
('S K K' , 'I' ),
636
806
('K I' , 'π' ),
637
807
('K (S K K)' , 'π' ),
@@ -700,17 +870,44 @@ def check() -> int:
700
870
('B (a b) c' , 'D a b c' ),
701
871
]
702
872
ec = 0
703
- for testinput , expected in [(key , key ) for key in KNOWN_COMBINATORS ] + checks :
873
+ print ('Combinator checks' )
874
+ for testinput , expected in [(key , key ) for key in KNOWN_COMBINATORS ] + combinator_checks :
704
875
resexpr = from_string (testinput )
705
876
res = to_string (resexpr )
706
877
if res != expected :
707
878
if expected in KNOWN_COMBINATORS :
708
- print (f'❌ { testinput } → { res } { resexpr } but { expected } { from_string (expected )} = { KNOWN_COMBINATORS [expected ]} expected' )
879
+ print (f'❌ { testinput } ⇒ { res } { resexpr } but { expected } { from_string (expected )} = { KNOWN_COMBINATORS [expected ]} expected' )
709
880
else :
710
- print (f'❌ { testinput } → { res } { resexpr } but { expected } { from_string (expected )} expected' )
881
+ print (f'❌ { testinput } ⇒ { res } { resexpr } but { expected } { from_string (expected )} expected' )
882
+ ec = 1
883
+ else :
884
+ print (f'✅ { testinput } ⇒ { res } ' )
885
+
886
+ signature_checks = [
887
+ ('B' , '(a → b) → (c → a) → c → b' ),
888
+ ('C' , '(a → b → c) → b → a → c' ),
889
+ ('C*' , '(a → b → c → d) → a → c → b → d' ),
890
+ ('I' , 'a → a' ),
891
+ ('I*' , '(a → b) → a → b' ),
892
+ ('K' , 'a → b → a' ),
893
+ ('Ψ' , '(a → a → b) → (c → a) → c → c → b' ),
894
+ ('Q' , '(a → b) → (b → c) → a → c' ),
895
+ ('R' , 'a → (b → a → c) → b → c' ),
896
+ ('R*' , '(a → b → c → d) → c → a → b → d' ),
897
+ ('S' , '(a → b → c) → (a → b) → a → c' ),
898
+ ('Φ' , '(a → b → c) → (d → a) → (d → b) → d → c' ),
899
+ ('T' , 'a → (a → b) → b' ),
900
+ ('W' , '(a → a → b) → a → b' ),
901
+ ]
902
+ print ('\n Signature checks' )
903
+ for testinput , expected in signature_checks :
904
+ resexpr = from_string (testinput )
905
+ res = to_typesig (resexpr )
906
+ if res != expected :
907
+ print (f'❌ { testinput } ⇒ { res } but { expected } expected' )
711
908
ec = 1
712
909
else :
713
- print (f'✅ { testinput } → { res } ' )
910
+ print (f'✅ { testinput } ⇒ { res } ' )
714
911
return ec
715
912
716
913
0 commit comments