Skip to content

Commit 14cea3a

Browse files
authored
Move SI information from BVS args to annotation (#520)
* Move SI information from BVS args to annotation
* Add docstrings
* Refactor BackendVSA convert and abstract
1 parent 9d1176f commit 14cea3a

File tree

10 files changed

+134
-89
lines changed

10 files changed

+134
-89
lines changed

claripy/annotation.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
from __future__ import annotations
22

3-
import claripy
3+
from typing import TYPE_CHECKING
4+
45
from claripy.errors import ClaripyOperationError
56

7+
if TYPE_CHECKING:
8+
import claripy
9+
610

711
class Annotation:
812
"""
@@ -64,6 +68,8 @@ def relocate(self, src: claripy.ast.Base, dst: claripy.ast.Base): # pylint:disa
6468

6569

6670
class SimplificationAvoidanceAnnotation(Annotation):
71+
"""SimplificationAvoidanceAnnotation is an annotation that prevents simplification of an AST."""
72+
6773
@property
6874
def eliminatable(self):
6975
return False
@@ -73,6 +79,33 @@ def relocatable(self):
7379
return False
7480

7581

82+
class StridedIntervalAnnotation(SimplificationAvoidanceAnnotation):
    """
    StridedIntervalAnnotation allows annotating a BVS to represent a strided interval.

    Two annotations compare equal, and hash identically, when their stride,
    lower bound and upper bound all match.
    """

    # Stride of the annotated interval.
    stride: int
    # Lower bound of the annotated interval.
    lower_bound: int
    # Upper bound of the annotated interval.
    upper_bound: int

    def __init__(self, stride: int, lower_bound: int, upper_bound: int):
        """
        :param stride:      The stride of the interval.
        :param lower_bound: The lower bound of the interval.
        :param upper_bound: The upper bound of the interval.
        """
        self.stride = stride
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

    def __hash__(self) -> int:
        # Hash exactly the fields that __eq__ compares so equal objects hash equal.
        return hash((self.stride, self.lower_bound, self.upper_bound))

    def __eq__(self, other: object) -> bool:
        """
        :param other: The object to compare against.
        :return:      True if ``other`` is a StridedIntervalAnnotation with the same
                      stride, lower bound and upper bound; False otherwise.
        """
        # The upper bound must be compared with other.upper_bound. Comparing it
        # with `other` itself made annotations with identical fields unequal
        # even though their hashes matched.
        return (
            isinstance(other, StridedIntervalAnnotation)
            and self.stride == other.stride
            and self.lower_bound == other.lower_bound
            and self.upper_bound == other.upper_bound
        )

    def __repr__(self) -> str:
        return f"<StridedIntervalAnnotation {self.stride}:{self.lower_bound} - {self.upper_bound}>"
107+
108+
76109
class RegionAnnotation(SimplificationAvoidanceAnnotation):
77110
"""
78111
Use RegionAnnotation to annotate ASTs. Normally, an AST annotated by
@@ -84,12 +117,6 @@ def __init__(self, region_id, region_base_addr, offset):
84117
self.region_base_addr = region_base_addr
85118
self.offset = offset
86119

87-
# Do necessary conversion here
88-
if isinstance(self.region_base_addr, claripy.ast.Base):
89-
self.region_base_addr = claripy.backends.vsa.convert(self.region_base_addr)
90-
if isinstance(self.offset, claripy.ast.Base):
91-
self.offset = claripy.backends.vsa.convert(self.offset)
92-
93120
#
94121
# Overriding base methods
95122
#

claripy/ast/base.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
ArgType = Union["Base", bool, int, float, str, FSort, tuple["ArgType"], None]
3636

3737
T = TypeVar("T", bound="Base")
38+
A = TypeVar("A", bound="Annotation")
3839

3940

4041
class ASTCacheKey(Generic[T]):
@@ -578,7 +579,7 @@ def has_annotation_type(self, annotation_type: type[Annotation]) -> bool:
578579
"""
579580
return any(isinstance(a, annotation_type) for a in self.annotations)
580581

581-
def get_annotations_by_type(self, annotation_type: type[Annotation]) -> tuple[Annotation, ...]:
582+
def get_annotations_by_type(self, annotation_type: type[A]) -> tuple[A, ...]:
582583
"""
583584
Get all annotations of a given type.
584585
@@ -587,6 +588,18 @@ def get_annotations_by_type(self, annotation_type: type[Annotation]) -> tuple[An
587588
"""
588589
return tuple(a for a in self.annotations if isinstance(a, annotation_type))
589590

591+
def get_annotation(self, annotation_type: type[A]) -> A | None:
    """
    Look up a single annotation by type.

    Annotations are scanned in the order they appear in ``self.annotations``
    and the first instance of ``annotation_type`` is returned.

    :param annotation_type: The type of the annotation.
    :return: The annotation of the given type, or None if not found.
    """
    matching = (anno for anno in self.annotations if isinstance(anno, annotation_type))
    # next() with a default stops at the first match and yields None when there is none.
    return next(matching, None)
602+
590603
def append_annotation(self, a: Annotation) -> Self:
591604
"""
592605
Appends an annotation to this AST.
@@ -766,17 +779,7 @@ def _op_repr(
766779
):
767780
if details < ReprLevel.FULL_REPR:
768781
if op == "BVS":
769-
extras = []
770-
if args[1] is not None:
771-
fmt = "%#x" if isinstance(args[1], int) else "%s"
772-
extras.append("min=%s" % (fmt % args[1]))
773-
if args[2] is not None:
774-
fmt = "%#x" if isinstance(args[2], int) else "%s"
775-
extras.append("max=%s" % (fmt % args[2]))
776-
if args[3] is not None:
777-
fmt = "%#x" if isinstance(args[3], int) else "%s"
778-
extras.append("stride=%s" % (fmt % args[3]))
779-
return "{}{}".format(args[0], "{{{}}}".format(", ".join(extras)) if extras else "")
782+
return f"{args[0]}"
780783

781784
if op == "BoolV":
782785
return str(args[0])

claripy/ast/bv.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,6 @@ def identical(self, other: Self, strict=False) -> bool:
199199
def BVS( # pylint:disable=redefined-builtin
200200
name,
201201
size,
202-
min=None,
203-
max=None,
204-
stride=None,
205202
explicit_name=None,
206203
**kwargs,
207204
) -> BV:
@@ -213,17 +210,11 @@ def BVS( # pylint:disable=redefined-builtin
213210
214211
:param name: The name of the symbol.
215212
:param size: The size (in bits) of the bit-vector.
216-
:param min: The minimum value of the symbol, used only for value-set analysis
217-
:param max: The maximum value of the symbol, used only for value-set analysis
218-
:param stride: The stride of the symbol, used only for value-set analysis
219213
:param bool explicit_name: If False, an identifier is appended to the name to ensure uniqueness.
220214
221215
:returns: a BV object representing this symbol.
222216
"""
223217

224-
if stride == 0 and max != min:
225-
raise ClaripyValueError("BVSes of stride 0 should have max == min")
226-
227218
if isinstance(name, bytes):
228219
name = name.decode()
229220
if not isinstance(name, str):
@@ -234,7 +225,7 @@ def BVS( # pylint:disable=redefined-builtin
234225

235226
return BV(
236227
"BVS",
237-
(n, min, max, stride),
228+
(n,),
238229
variables=frozenset((n,)),
239230
length=size,
240231
symbolic=True,
@@ -301,16 +292,11 @@ def SI(
301292
si = claripy.backends.backend_vsa.CreateStridedInterval(
302293
name=name, bits=bits, lower_bound=lower_bound, upper_bound=upper_bound, stride=stride, to_conv=to_conv
303294
)
304-
return BVS(
305-
name, si._bits, min=si._lower_bound, max=si._upper_bound, stride=si._stride, explicit_name=explicit_name
295+
return BVS(name, si._bits, explicit_name=explicit_name).annotate(
296+
claripy.annotation.StridedIntervalAnnotation(si._stride, si._lower_bound, si._upper_bound)
306297
)
307-
return BVS(
308-
name,
309-
bits,
310-
min=lower_bound,
311-
max=upper_bound,
312-
stride=stride,
313-
explicit_name=explicit_name,
298+
return BVS(name, bits, explicit_name=explicit_name).annotate(
299+
claripy.annotation.StridedIntervalAnnotation(stride, lower_bound, upper_bound)
314300
)
315301

316302

@@ -331,6 +317,8 @@ def ValueSet(bits, region=None, region_base_addr=None, value=None, name=None, va
331317
region_base_addr = 0
332318

333319
v = region_base_addr + value
320+
if isinstance(v, claripy.ast.Base):
321+
v = claripy.simplify(v)
334322

335323
# Backward compatibility
336324
if isinstance(v, numbers.Number):
@@ -340,19 +328,25 @@ def ValueSet(bits, region=None, region_base_addr=None, value=None, name=None, va
340328
min_v, max_v = v.lower_bound, v.upper_bound
341329
stride = v.stride
342330
elif isinstance(v, claripy.ast.Base):
343-
sv = claripy.simplify(v)
344-
if sv.op == "BVS":
345-
min_v = sv.args[1]
346-
max_v = sv.args[2]
347-
stride = sv.args[3]
331+
si_anno = v.get_annotation(claripy.annotation.StridedIntervalAnnotation)
332+
if si_anno is not None:
333+
min_v = si_anno.lower_bound
334+
max_v = si_anno.upper_bound
335+
stride = si_anno.stride
336+
elif v.op == "BVV":
337+
min_v = v.args[0]
338+
max_v = v.args[0]
339+
stride = 0
348340
else:
349-
raise ClaripyValueError(f"ValueSet() does not take `value` ast with op {sv.op}")
341+
raise ClaripyValueError(f"ValueSet() does not take `value` ast with op {v.op}")
350342
else:
351343
raise ClaripyValueError(f"ValueSet() does not take `value` of type {type(value)}")
352344

353345
if name is None:
354346
name = "ValueSet"
355-
bvs = BVS(name, bits, min=region_base_addr + min_v, max=region_base_addr + max_v, stride=stride)
347+
bvs = BVS(name, bits).annotate(
348+
claripy.annotation.StridedIntervalAnnotation(stride, region_base_addr + min_v, region_base_addr + max_v)
349+
)
356350

357351
# Annotate the bvs and return the new AST
358352
return bvs.annotate(claripy.annotation.RegionAnnotation(region, region_base_addr, value))

claripy/backends/backend_vsa/backend_vsa.py

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
import operator
77
from functools import reduce
88

9-
from claripy.annotation import RegionAnnotation
9+
from claripy.annotation import RegionAnnotation, StridedIntervalAnnotation
1010
from claripy.ast.base import Base
11-
from claripy.ast.bv import ESI, SI, TSI
11+
from claripy.ast.bv import BV, BVV, ESI, SI, TSI, VS
1212
from claripy.backends.backend import Backend
13+
from claripy.backends.backend_vsa.errors import ClaripyVSAError
1314
from claripy.balancer import Balancer
1415
from claripy.errors import BackendError
1516
from claripy.operations import backend_operations_vsa_compliant, expression_set_operations
@@ -128,13 +129,28 @@ def _abstract(self, e):
128129
return TSI(e.bits, explicit_name=e.name)
129130
if e.is_bottom:
130131
return ESI(e.bits)
132+
if e.stride in {0, 1} and e.lower_bound == e.upper_bound:
133+
return BVV(e.lower_bound, e.bits)
131134
return SI(
132135
name=e.name,
133136
bits=e.bits,
134137
lower_bound=e.lower_bound,
135138
upper_bound=e.upper_bound,
136139
stride=e.stride,
137140
)
141+
if isinstance(e, ValueSet):
142+
if len(e.regions) == 0:
143+
return VS(bits=e.bits, name=e.name)
144+
if len(e.regions) == 1:
145+
region = next(iter(e.regions))
146+
return VS(
147+
bits=e.bits,
148+
region=region,
149+
region_base_addr=e._region_base_addrs[region].eval(1)[0] if e._region_base_addrs else 0,
150+
value=e.regions[region].eval(1)[0],
151+
name=e.name,
152+
)
153+
raise ClaripyVSAError("Cannot abstract ValueSet with multiple regions")
138154
raise BackendError(f"Don't know how to abstract {type(e)}")
139155

140156
def _eval(self, expr, n, extra_constraints=(), solver=None, model_callback=None):
@@ -230,18 +246,43 @@ def apply_annotation(self, o, a):
230246
:rtype: BackendObject
231247
"""
232248

233-
# Currently we only support RegionAnnotation
234-
235-
if not isinstance(a, RegionAnnotation):
236-
return o
237-
238-
if not isinstance(o, ValueSet):
239-
# Convert it to a ValueSet first
240-
# Note that the original value is not kept at all. If you want to convert a StridedInterval to a ValueSet,
241-
# you gotta do the conversion by calling AST.annotate() from outside.
242-
o = ValueSet.empty(o.bits)
249+
if isinstance(o, StridedInterval):
250+
if isinstance(a, StridedIntervalAnnotation):
251+
return CreateStridedInterval(
252+
bits=o.bits,
253+
stride=a.stride,
254+
lower_bound=a.lower_bound,
255+
upper_bound=a.upper_bound,
256+
name=o.name,
257+
)
258+
259+
if isinstance(a, RegionAnnotation):
260+
offset = self.convert(a.offset)
261+
if isinstance(offset, numbers.Number):
262+
offset = StridedInterval(bits=o.bits, stride=0, lower_bound=offset, upper_bound=offset)
263+
vs = ValueSet.empty(o.bits)
264+
if isinstance(offset, StridedInterval):
265+
vs._merge_si(a.region_id, a.region_base_addr, offset)
266+
elif isinstance(offset, ValueSet):
267+
for si in offset.regions.values():
268+
vs._merge_si(a.region_id, a.region_base_addr, si)
269+
else:
270+
raise ClaripyVSAError(f"Unsupported offset type {type(offset)}")
271+
return vs
272+
273+
if isinstance(o, ValueSet) and isinstance(a, StridedIntervalAnnotation):
274+
si = CreateStridedInterval(
275+
bits=o.bits,
276+
stride=a.stride,
277+
lower_bound=a.lower_bound,
278+
upper_bound=a.upper_bound,
279+
name=o.name,
280+
)
281+
vs = o.copy()
282+
vs._merge_si(a.region_id, a.region_base_addr, si)
283+
return vs
243284

244-
return o.apply_annotation(a)
285+
raise ValueError(f"Unsupported annotation type {type(a)} for object {type(o)}")
245286

246287
@staticmethod
247288
def BVV(ast):
@@ -302,16 +343,8 @@ def SGE(a, b):
302343
return a.SGE(b)
303344

304345
@staticmethod
305-
def BVS(ast: Base):
306-
size = ast.size()
307-
name, mn, mx, stride = ast.args
308-
return CreateStridedInterval(
309-
name=name,
310-
bits=size,
311-
lower_bound=mn,
312-
upper_bound=mx,
313-
stride=stride,
314-
)
346+
def BVS(ast: BV):
    """
    Create a strided interval for a BVS AST.

    Only the symbol name (the first AST argument) and the bit size are passed
    to CreateStridedInterval; no bounds or stride are read from the AST arguments.

    :param ast: The BVS AST to convert.
    :return:    The result of CreateStridedInterval for that name and size.
    """
    symbol_name = ast.args[0]
    bit_width = ast.size()
    return CreateStridedInterval(name=symbol_name, bits=bit_width)
315348

316349
def If(self, cond, t, f):
317350
if not self.has_true(cond):

claripy/backends/backend_vsa/valueset.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -209,19 +209,6 @@ def get_si(self, region):
209209
def stridedinterval(self):
210210
return self._si
211211

212-
def apply_annotation(self, annotation):
213-
"""
214-
Apply a new annotation onto self, and return a new ValueSet object.
215-
216-
:param RegionAnnotation annotation: The annotation to apply.
217-
:return: A new ValueSet object
218-
:rtype: ValueSet
219-
"""
220-
221-
vs = self.copy()
222-
vs._merge_si(annotation.region_id, annotation.region_base_addr, annotation.offset)
223-
return vs
224-
225212
def __repr__(self):
226213
s = ""
227214
for region, si in self._regions.items():

claripy/balancer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def _replacements_iter(self):
5252
min_int = 0
5353
mn = self._lower_bounds.get(k, min_int)
5454
mx = self._upper_bounds.get(k, max_int)
55-
bound_si = BVS("bound", len(k.ast), min=mn, max=mx)
55+
bound_si = BVS("bound", len(k.ast)).annotate(claripy.annotation.StridedIntervalAnnotation(1, mn, mx))
5656
l.debug("Yielding bound %s for %s.", bound_si, k.ast)
5757
if k.ast.op == "Reverse":
5858
yield (k.ast.args[0], k.ast.intersection(bound_si).reversed)

tests/test_expression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def test_expression(self):
180180

181181
def test_cardinality(self):
182182
x = claripy.BVS("x", 32)
183-
y = claripy.BVS("y", 32, min=100, max=120)
183+
y = claripy.BVS("y", 32).annotate(claripy.annotation.StridedIntervalAnnotation(1, 100, 120))
184184
n = claripy.BVV(10, 32)
185185
m = claripy.BVV(20, 32)
186186

tests/test_simplify.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import unittest
55

66
import claripy
7+
import claripy.annotation
78

89

910
class TestSimplify(unittest.TestCase):
@@ -56,7 +57,7 @@ def assert_correct(a, b):
5657
assert_correct(x % y, claripy.backends.z3.simplify(x % y))
5758

5859
def test_rotate_shift_mask_simplification(self):
59-
a = claripy.BVS("N", 32, max=0xC, min=0x1)
60+
a = claripy.BVS("N", 32).annotate(claripy.annotation.StridedIntervalAnnotation(1, 0x1, 0xC))
6061
extend_ = claripy.BVS("extend", 32, uninitialized=True)
6162
a_ext = extend_.concat(a)
6263
expr = ((a_ext << 3) | (claripy.LShR(a_ext, 61))) & 0x7FFFFFFF8

0 commit comments

Comments
 (0)