We can implement a simple bottom-up rewriter (shown below) which, given a term s, a function f, and a term t, replaces each f-application f(r_1, ..., r_n) in s with t[r_1, ..., r_n]. I use the notation t[r_1, ..., r_n] to denote the term obtained by replacing the free variables in t with the terms r_1, ..., r_n; the free variable with de Bruijn index i (written Var(i, ...) in Z3) is replaced with r_{i+1}.
The rewriter can be implemented using the Z3 API. I use an AstMap to cache results and a list todo to store expressions that still need to be processed.
Here is a simple example that replaces f-applications of the form f(t) with g(t+1) in s:
# Var(0, IntSort()) is a placeholder for the first argument of each f-application.
x = Var(0, IntSort())
print(rewrite(s, f, g(x + 1)))
Here is the code and some examples. Beware: I have only tested the code on a small set of examples.
from z3 import *
def update_term(t, args):
    """
    Rebuild ``t`` with its immediate arguments replaced by ``args``.

    Thin wrapper over the C API function ``Z3_update_term``. The z3py
    expressions in ``args`` are unwrapped into a ctypes array of raw ASTs.
    The raw AST that comes back is wrapped into a z3py expression in the
    same context as ``t``.

    Args:
        t: the expression to rebuild.
        args: the replacement arguments, in order. ``rewrite`` passes the
            rewritten arguments of an application, or a one-element list
            holding the rewritten body of a quantifier.

    Returns:
        A z3py expression for the updated term.
    """
    arity = len(args)
    raw_args = (Ast * arity)()
    for position, child in enumerate(args):
        raw_args[position] = child.as_ast()
    ctx = t.ctx
    raw_result = Z3_update_term(t.ctx_ref(), t.as_ast(), arity, raw_args)
    # NOTE(review): ``_to_expr_ref`` is a private z3py helper reached through
    # the name ``z3`` (bound by the star import); confirm it exists in the
    # installed z3py version.
    return z3._to_expr_ref(raw_result, ctx)
def rewrite(s, f, t):
    """
    Replace f-applications f(r_1, ..., r_n) with t[r_1, ..., r_n] in s.

    The traversal is bottom-up. It uses an explicit ``todo`` stack instead
    of recursion, so deep terms do not hit Python's recursion limit. Each
    rewritten sub-term is memoised in an ``AstMap``, so shared sub-terms
    are rewritten only once.

    Args:
        s: the Z3 expression to rewrite.
        f: the function declaration whose applications are replaced.
        t: the replacement template. Its free variable with de Bruijn
            index i (``Var(i, ...)``) is replaced by the already rewritten
            i-th argument of the f-application, via ``substitute_vars``.

    Returns:
        The rewritten expression.

    Raises:
        TypeError: if a sub-term is neither a variable, an application,
            nor a quantifier.
    """
    todo = [s]
    cache = AstMap(ctx=s.ctx)
    while todo:
        n = todo[-1]
        if n in cache:
            # A shared sub-term (or a repeated argument such as in a + a)
            # can be pushed several times before it is processed; once one
            # copy is done, the remaining copies are simply discarded.
            todo.pop()
        elif is_var(n):
            # Variables are leaves and are never rewritten.
            todo.pop()
            cache[n] = n
        elif is_app(n):
            args = [n.arg(i) for i in range(n.num_args())]
            pending = [arg for arg in args if arg not in cache]
            if pending:
                # Rewrite the children first; n stays on the stack and is
                # revisited once all of its arguments are cached.
                todo.extend(pending)
                continue
            todo.pop()
            new_args = [cache[arg] for arg in args]
            if eq(n.decl(), f):
                cache[n] = substitute_vars(t, *new_args)
            else:
                cache[n] = update_term(n, new_args)
        elif is_quantifier(n):
            body = n.body()
            if body in cache:
                todo.pop()
                # A quantifier has a single child: its body.
                cache[n] = update_term(n, [cache[body]])
            else:
                todo.append(body)
        else:
            raise TypeError("rewrite: unsupported term %r" % (n,))
    return cache[s]
# Example 1: replace every application of the unary f with the constant b.
f = Function('f', IntSort(), IntSort())
a, b = Ints('a b')
s = Or(f(a) == 0, f(a) == 1, f(a+a) == 2)
print(rewrite(s, f, b))
# Example 2: replace f(r) with g(r + 1); Var(0) stands for f's argument.
g = Function('g', IntSort(), IntSort())
x = Var(0, IntSort())
print(rewrite(s, f, g(x + 1)))
# Example 3: binary f, including nested applications; f(r_1, r_2) is
# replaced with g(r_2, r_1) because x = Var(0) and y = Var(1).
f = Function('f', IntSort(), IntSort(), IntSort())
g = Function('g', IntSort(), IntSort(), IntSort())
s = Or(f(a, f(a, b)) == 0, f(b, a) == 1, f(f(1,0), 0) == 2)
y = Var(1, IntSort())
print(rewrite(s, f, g(y, x)))
# Example 4: rewriting inside the body of a quantifier.
s = ForAll([a], f(a, b) >= 0)
print(rewrite(s, f, g(y, x + 1)))