|
6 | 6 | import operator
|
7 | 7 | from functools import reduce
|
8 | 8 |
|
9 |
| -from claripy.annotation import RegionAnnotation |
| 9 | +from claripy.annotation import RegionAnnotation, StridedIntervalAnnotation |
10 | 10 | from claripy.ast.base import Base
|
11 |
| -from claripy.ast.bv import ESI, SI, TSI |
| 11 | +from claripy.ast.bv import BV, BVV, ESI, SI, TSI, VS |
12 | 12 | from claripy.backends.backend import Backend
|
| 13 | +from claripy.backends.backend_vsa.errors import ClaripyVSAError |
13 | 14 | from claripy.balancer import Balancer
|
14 | 15 | from claripy.errors import BackendError
|
15 | 16 | from claripy.operations import backend_operations_vsa_compliant, expression_set_operations
|
@@ -128,13 +129,28 @@ def _abstract(self, e):
|
128 | 129 | return TSI(e.bits, explicit_name=e.name)
|
129 | 130 | if e.is_bottom:
|
130 | 131 | return ESI(e.bits)
|
| 132 | + if e.stride in {0, 1} and e.lower_bound == e.upper_bound: |
| 133 | + return BVV(e.lower_bound, e.bits) |
131 | 134 | return SI(
|
132 | 135 | name=e.name,
|
133 | 136 | bits=e.bits,
|
134 | 137 | lower_bound=e.lower_bound,
|
135 | 138 | upper_bound=e.upper_bound,
|
136 | 139 | stride=e.stride,
|
137 | 140 | )
|
| 141 | + if isinstance(e, ValueSet): |
| 142 | + if len(e.regions) == 0: |
| 143 | + return VS(bits=e.bits, name=e.name) |
| 144 | + if len(e.regions) == 1: |
| 145 | + region = next(iter(e.regions)) |
| 146 | + return VS( |
| 147 | + bits=e.bits, |
| 148 | + region=region, |
| 149 | + region_base_addr=e._region_base_addrs[region].eval(1)[0] if e._region_base_addrs else 0, |
| 150 | + value=e.regions[region].eval(1)[0], |
| 151 | + name=e.name, |
| 152 | + ) |
| 153 | + raise ClaripyVSAError("Cannot abstract ValueSet with multiple regions") |
138 | 154 | raise BackendError(f"Don't know how to abstract {type(e)}")
|
139 | 155 |
|
140 | 156 | def _eval(self, expr, n, extra_constraints=(), solver=None, model_callback=None):
|
@@ -230,18 +246,43 @@ def apply_annotation(self, o, a):
|
230 | 246 | :rtype: BackendObject
|
231 | 247 | """
|
232 | 248 |
|
233 |
| - # Currently we only support RegionAnnotation |
234 |
| - |
235 |
| - if not isinstance(a, RegionAnnotation): |
236 |
| - return o |
237 |
| - |
238 |
| - if not isinstance(o, ValueSet): |
239 |
| - # Convert it to a ValueSet first |
240 |
| - # Note that the original value is not kept at all. If you want to convert a StridedInterval to a ValueSet, |
241 |
| - # you gotta do the conversion by calling AST.annotate() from outside. |
242 |
| - o = ValueSet.empty(o.bits) |
| 249 | + if isinstance(o, StridedInterval): |
| 250 | + if isinstance(a, StridedIntervalAnnotation): |
| 251 | + return CreateStridedInterval( |
| 252 | + bits=o.bits, |
| 253 | + stride=a.stride, |
| 254 | + lower_bound=a.lower_bound, |
| 255 | + upper_bound=a.upper_bound, |
| 256 | + name=o.name, |
| 257 | + ) |
| 258 | + |
| 259 | + if isinstance(a, RegionAnnotation): |
| 260 | + offset = self.convert(a.offset) |
| 261 | + if isinstance(offset, numbers.Number): |
| 262 | + offset = StridedInterval(bits=o.bits, stride=0, lower_bound=offset, upper_bound=offset) |
| 263 | + vs = ValueSet.empty(o.bits) |
| 264 | + if isinstance(offset, StridedInterval): |
| 265 | + vs._merge_si(a.region_id, a.region_base_addr, offset) |
| 266 | + elif isinstance(offset, ValueSet): |
| 267 | + for si in offset.regions.values(): |
| 268 | + vs._merge_si(a.region_id, a.region_base_addr, si) |
| 269 | + else: |
| 270 | + raise ClaripyVSAError(f"Unsupported offset type {type(offset)}") |
| 271 | + return vs |
| 272 | + |
| 273 | + if isinstance(o, ValueSet) and isinstance(a, StridedIntervalAnnotation): |
| 274 | + si = CreateStridedInterval( |
| 275 | + bits=o.bits, |
| 276 | + stride=a.stride, |
| 277 | + lower_bound=a.lower_bound, |
| 278 | + upper_bound=a.upper_bound, |
| 279 | + name=o.name, |
| 280 | + ) |
| 281 | + vs = o.copy() |
| 282 | + vs._merge_si(a.region_id, a.region_base_addr, si) |
| 283 | + return vs |
243 | 284 |
|
244 |
| - return o.apply_annotation(a) |
| 285 | + raise ValueError(f"Unsupported annotation type {type(a)} for object {type(o)}") |
245 | 286 |
|
246 | 287 | @staticmethod
|
247 | 288 | def BVV(ast):
|
@@ -302,16 +343,8 @@ def SGE(a, b):
|
302 | 343 | return a.SGE(b)
|
303 | 344 |
|
304 | 345 | @staticmethod
|
305 |
| - def BVS(ast: Base): |
306 |
| - size = ast.size() |
307 |
| - name, mn, mx, stride = ast.args |
308 |
| - return CreateStridedInterval( |
309 |
| - name=name, |
310 |
| - bits=size, |
311 |
| - lower_bound=mn, |
312 |
| - upper_bound=mx, |
313 |
| - stride=stride, |
314 |
| - ) |
| 346 | + def BVS(ast: BV): |
| 347 | + return CreateStridedInterval(name=ast.args[0], bits=ast.size()) |
315 | 348 |
|
316 | 349 | def If(self, cond, t, f):
|
317 | 350 | if not self.has_true(cond):
|
|
0 commit comments