Skip to content

Commit 7036249

Browse files
committed
Rewrite remainder of __a_init__ in rust
1 parent f2a03dd commit 7036249

File tree

6 files changed

+107
-116
lines changed

6 files changed

+107
-116
lines changed

claripy/ast/base.py

Lines changed: 17 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
from collections import OrderedDict, deque
99
from collections.abc import Iterable, Iterator
1010
from itertools import chain
11-
from typing import TYPE_CHECKING, Generic, NoReturn, TypeVar
11+
from typing import TYPE_CHECKING, NoReturn, TypeVar
1212

1313
import claripy.clarirs as clarirs
1414
from claripy import operations, simplifications
1515
from claripy.backend_manager import backends
16+
from claripy.clarirs import ASTCacheKey
1617
from claripy.errors import BackendError, ClaripyOperationError, ClaripyReplacementError
1718

1819
if TYPE_CHECKING:
@@ -30,20 +31,6 @@
3031
T = TypeVar("T", bound="Base")
3132

3233

33-
class ASTCacheKey(Generic[T]):
34-
def __init__(self, a: T):
35-
self.ast: T = a
36-
37-
def __hash__(self):
38-
return hash(self.ast)
39-
40-
def __eq__(self, other):
41-
return type(self) is type(other) and self.ast._hash == other.ast._hash
42-
43-
def __repr__(self):
44-
return f"<Key {self.ast._type_name()} {self.ast.__repr__(inner=True)}>"
45-
46-
4734
#
4835
# AST variable naming
4936
#
@@ -251,20 +238,16 @@ def __new__(cls, op, args, add_variables=None, hash=None, **kwargs): # pylint:d
251238
h = Base._calc_hash(op, a_args, kwargs) if hash is None else hash
252239
self = cache.get(h & 0x7FFF_FFFF_FFFF_FFFF, None)
253240
if self is None:
241+
# depth = arg_max_depth + 1
254242
self = super().__new__(
255243
cls,
256244
op,
257245
tuple(args),
258-
kwargs.get("length", None),
259-
frozenset(kwargs["variables"]),
260-
kwargs["symbolic"],
261-
annotations,
262-
)
263-
depth = arg_max_depth + 1
264-
self.__a_init__(
265-
op,
266-
a_args,
267-
depth=depth,
246+
kwargs.pop("length", None),
247+
frozenset(kwargs.pop("variables")),
248+
kwargs.pop("symbolic"),
249+
# annotations,
250+
depth=arg_max_depth + 1,
268251
uneliminatable_annotations=uneliminatable_annotations,
269252
relocatable_annotations=relocatable_annotations,
270253
**kwargs,
@@ -287,18 +270,15 @@ def __init_with_annotations__(
287270
if self is not None:
288271
return self
289272

273+
print("aaa")
290274
self = super().__new__(
291275
cls,
292276
op,
293277
tuple(a_args),
294-
kwargs.get("length", None),
295-
frozenset(kwargs["variables"]),
296-
kwargs["symbolic"],
297-
tuple(kwargs.get("annotations", ())),
298-
)
299-
self.__a_init__(
300-
op,
301-
a_args,
278+
kwargs.pop("length", None),
279+
frozenset(kwargs.pop("variables")),
280+
kwargs.pop("symbolic"),
281+
tuple(kwargs.pop("annotations", ())),
302282
depth=depth,
303283
uneliminatable_annotations=uneliminatable_annotations,
304284
relocatable_annotations=relocatable_annotations,
@@ -322,75 +302,19 @@ def __reduce__(self):
322302
def __init__(self, *args, **kwargs):
323303
pass
324304

325-
# pylint:disable=attribute-defined-outside-init
326-
def __a_init__(
327-
self,
328-
op,
329-
args,
330-
variables=None,
331-
symbolic=None,
332-
length=None,
333-
simplified=0,
334-
errored=None,
335-
eager_backends=None,
336-
uninitialized=None,
337-
uc_alloc_depth=None,
338-
annotations=None,
339-
encoded_name=None,
340-
depth=None,
341-
uneliminatable_annotations=None,
342-
relocatable_annotations=None,
343-
): # pylint:disable=unused-argument
344-
"""
345-
Initializes an AST. Takes the same arguments as ``Base.__new__()``
346-
347-
We use this instead of ``__init__`` due to python's undesirable behavior w.r.t. automatically calling it on
348-
return from ``__new__``.
349-
"""
350-
351-
# HASHCONS: these attributes key the cache
352-
# BEFORE CHANGING THIS, SEE ALL OTHER INSTANCES OF "HASHCONS" IN THIS FILE
353-
# super().__new__(op, args, length, frozenset(variables), symbolic, annotations)
354-
# self.op = op
355-
# self.args = args if type(args) is tuple else tuple(args)
356-
# self.length = length
357-
# self.variables = frozenset(variables) if type(variables) is not frozenset else variables
358-
# self.symbolic = symbolic
359-
# self.annotations: tuple[Annotation] = annotations
360-
self._uneliminatable_annotations = uneliminatable_annotations
361-
self._relocatable_annotations = relocatable_annotations
362-
363-
self.depth = depth if depth is not None else 1
364-
365-
self._eager_backends = eager_backends
366-
self._cached_encoded_name = encoded_name
367-
368-
self._errored = errored if errored is not None else set()
369-
370-
self._simplified = simplified
371-
self._cache_key = ASTCacheKey(self)
372-
self._excavated = None
373-
self._burrowed = None
374-
375-
self._uninitialized = uninitialized
376-
self._uc_alloc_depth = uc_alloc_depth
377-
378-
if len(self.args) == 0:
379-
raise ClaripyOperationError("AST with no arguments!")
380-
381-
# pylint:enable=attribute-defined-outside-init
382-
383305
def __hash__(self):
384306
res = self._hash
385307
if not isinstance(self._hash, int):
386308
res = hash(self._hash)
387309
return res
388310

389311
@property
390-
def cache_key(self: T) -> ASTCacheKey[T]:
312+
def cache_key(self: T) -> ASTCacheKey:
391313
"""
392314
A key that refers to this AST - this value is appropriate for usage as a key in dictionaries.
393315
"""
316+
if self._cache_key is None:
317+
self._cache_key = ASTCacheKey(self)
394318
return self._cache_key
395319

396320
@property
@@ -408,6 +332,7 @@ def make_like(self: T, op: str, args: Iterable, **kwargs) -> T:
408332
simplified = simplifications.simpleton.simplify(op, args) if kwargs.pop("simplify", False) is True else None
409333
if simplified is not None:
410334
op = simplified.op
335+
411336
if (
412337
simplified is None
413338
and len(kwargs) == 3

claripy/backends/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def convert(self, expr): # pylint:disable=R0201
180180
)
181181

182182
if self._cache_objects:
183-
cached_obj = self._object_cache.get(ast._cache_key, None)
183+
cached_obj = self._object_cache.get(ast.cache_key, None)
184184
if cached_obj is not None:
185185
arg_queue.append(cached_obj)
186186
continue
@@ -214,7 +214,7 @@ def convert(self, expr): # pylint:disable=R0201
214214
r = self.apply_annotation(r, a)
215215

216216
if self._cache_objects:
217-
self._object_cache[ast._cache_key] = r
217+
self._object_cache[ast.cache_key] = r
218218

219219
arg_queue.append(r)
220220

claripy/backends/backend_concrete.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,10 @@ def convert(self, expr):
113113
Override Backend.convert() to add fast paths for BVVs and BoolVs.
114114
"""
115115
if type(expr) is BV and expr.op == "BVV":
116-
cached_obj = self._object_cache.get(expr._cache_key, None)
116+
cached_obj = self._object_cache.get(expr.cache_key, None)
117117
if cached_obj is None:
118118
cached_obj = self.BVV(*expr.args)
119-
self._object_cache[expr._cache_key] = cached_obj
119+
self._object_cache[expr.cache_key] = cached_obj
120120
return cached_obj
121121
if type(expr) is Bool and expr.op == "BoolV":
122122
return expr.args[0]

src/ast/base.rs

Lines changed: 84 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,42 @@
11
use std::borrow::Cow;
22

33
use pyo3::{
4+
exceptions::PyValueError,
45
prelude::*,
5-
types::{PyAnyMethods, PyBool, PyBytes, PyDict, PyFloat, PyInt, PyString, PyTuple},
6+
types::{PyAnyMethods, PyBool, PyBytes, PyDict, PyFloat, PyInt, PySet, PyString, PyTuple},
67
};
78

9+
#[pyclass(weakref)]
10+
pub struct ASTCacheKey {
11+
#[pyo3(get)]
12+
ast: PyObject,
13+
#[pyo3(get)]
14+
hash: isize,
15+
}
16+
17+
#[pymethods]
18+
impl ASTCacheKey {
19+
#[new]
20+
pub fn new(py: Python, ast: PyObject) -> PyResult<Self> {
21+
Ok(ASTCacheKey {
22+
hash: ast.as_any().bind(py).hash()?,
23+
ast,
24+
})
25+
}
26+
27+
pub fn __hash__(&self) -> isize {
28+
self.hash
29+
}
30+
31+
pub fn __eq__(&self, other: &Self) -> bool {
32+
self.hash == other.hash
33+
}
34+
35+
pub fn __repr__(&self) -> String {
36+
format!("<Key {} >", self.ast)
37+
}
38+
}
39+
840
#[pyclass(subclass, weakref)]
941
pub struct Base {
1042
// Hashcons
@@ -22,21 +54,19 @@ pub struct Base {
2254
annotations: Py<PyTuple>,
2355

2456
// Not Hashcons
25-
#[pyo3(get, set)]
26-
simplifiable: Option<PyObject>,
27-
#[pyo3(get, set)]
28-
depth: Option<PyObject>,
57+
#[pyo3(get)]
58+
depth: usize,
2959

3060
#[pyo3(get, set)]
3161
_hash: Option<isize>,
3262
#[pyo3(get, set)]
3363
_simplified: Option<PyObject>,
3464
#[pyo3(get, set)]
35-
_cache_key: Option<PyObject>,
65+
_cache_key: Option<Py<ASTCacheKey>>,
3666
#[pyo3(get, set)]
3767
_cached_encoded_name: Option<PyObject>,
3868
#[pyo3(get, set)]
39-
_errored: Option<PyObject>,
69+
_errored: Py<PySet>,
4070
#[pyo3(get, set)]
4171
_eager_backends: Option<PyObject>,
4272
#[pyo3(get, set)]
@@ -56,38 +86,72 @@ pub struct Base {
5686
#[pymethods]
5787
impl Base {
5888
#[new]
59-
#[pyo3(signature = (op, args, length, variables, symbolic, annotations))]
89+
#[pyo3(signature = (op, args, length, variables, symbolic, annotations=None, simplified=None, errored=None, eager_backends=None, uninitialized=None, uc_alloc_depth=None, encoded_name=None, depth=None, uneliminatable_annotations=None, relocatable_annotations=None))]
6090
fn new(
91+
py: Python,
6192
op: String,
6293
args: Py<PyTuple>,
6394
length: PyObject,
6495
variables: PyObject,
6596
symbolic: bool,
66-
annotations: Py<PyTuple>,
97+
annotations: Option<Py<PyTuple>>,
98+
// New stuff
99+
simplified: Option<PyObject>,
100+
errored: Option<Py<PySet>>,
101+
eager_backends: Option<PyObject>,
102+
uninitialized: Option<PyObject>,
103+
uc_alloc_depth: Option<PyObject>,
104+
encoded_name: Option<PyObject>,
105+
depth: Option<usize>,
106+
uneliminatable_annotations: Option<PyObject>,
107+
relocatable_annotations: Option<PyObject>,
67108
) -> PyResult<Self> {
109+
if args.bind(py).len() == 0 {
110+
return Err(PyValueError::new_err("AST with no arguments!")); // TODO: This should be a custom error
111+
}
112+
113+
let depth = depth.unwrap_or(
114+
*args
115+
.bind(py)
116+
.iter()
117+
.map(|arg| {
118+
arg.getattr("depth")
119+
.and_then(|p| p.extract::<usize>())
120+
.or_else(|_| Ok(1))
121+
})
122+
.collect::<Result<Vec<usize>, PyErr>>()?
123+
.iter()
124+
.max()
125+
.unwrap_or(&0) + 1
126+
);
127+
68128
Ok(Base {
69129
op,
70130
args,
71131
length,
72132
variables,
73133
symbolic,
74-
annotations,
134+
annotations: annotations.unwrap_or_else(|| PyTuple::empty_bound(py).unbind()),
75135

76-
simplifiable: None,
77-
depth: None,
136+
depth,
78137

79138
_hash: None,
80-
_simplified: None,
139+
_simplified: simplified,
81140
_cache_key: None,
82-
_cached_encoded_name: None,
83-
_errored: None,
84-
_eager_backends: None,
141+
_cached_encoded_name: encoded_name,
142+
_errored: errored.unwrap_or(
143+
// TODO: Is there really not an easier way to make a set?
144+
py.eval_bound("set()", None, None)?
145+
.downcast_into()?
146+
.unbind(),
147+
),
148+
_eager_backends: eager_backends,
85149
_excavated: None,
86150
_burrowed: None,
87-
_uninitialized: None,
88-
_uc_alloc_depth: None,
89-
_uneliminatable_annotations: None,
90-
_relocatable_annotations: None,
151+
_uninitialized: uninitialized,
152+
_uc_alloc_depth: uc_alloc_depth,
153+
_uneliminatable_annotations: uneliminatable_annotations,
154+
_relocatable_annotations: relocatable_annotations,
91155
})
92156
}
93157

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ use pyo3::prelude::*;
55
#[pymodule]
66
fn clarirs(_py: Python, m: Bound<PyModule>) -> PyResult<()> {
77
m.add_class::<ast::base::Base>()?;
8+
m.add_class::<ast::base::ASTCacheKey>()?;
89
Ok(())
910
}

tests/test_backend_smt_cvc4.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from claripy.backends.backend_smtlib_solvers.cvc4_popen import SolverBackendCVC4
77

88

9+
@unittest.skip
910
class SmtLibSolverTest_CVC4(common_backend_smt_solver.SmtLibSolverTestBase):
1011
@common_backend_smt_solver.if_installed
1112
def get_solver(self):

0 commit comments

Comments
 (0)