"""Helpers for implementing generic IR to IR transforms.""" from __future__ import annotations from typing import Final, Optional from mypyc.ir.ops import ( Assign, AssignMulti, BasicBlock, Box, Branch, Call, CallC, Cast, ComparisonOp, DecRef, Extend, FloatComparisonOp, FloatNeg, FloatOp, GetAttr, GetElementPtr, Goto, IncRef, InitStatic, IntOp, KeepAlive, LoadAddress, LoadErrorValue, LoadGlobal, LoadLiteral, LoadMem, LoadStatic, MethodCall, Op, OpVisitor, PrimitiveOp, RaiseStandardError, Return, SetAttr, SetElement, SetMem, Truncate, TupleGet, TupleSet, Unborrow, Unbox, Unreachable, Value, ) from mypyc.irbuild.ll_builder import LowLevelIRBuilder class IRTransform(OpVisitor[Optional[Value]]): """Identity transform. Subclass and override to perform changes to IR. Subclass IRTransform and override any OpVisitor visit_* methods that perform any IR changes. The default implementations implement an identity transform. A visit method can return None to remove ops. In this case the transform must ensure that no op uses the original removed op as a source after the transform. You can retain old BasicBlock and op references in ops. The transform will automatically patch these for you as needed. """ def __init__(self, builder: LowLevelIRBuilder) -> None: self.builder = builder # Subclasses add additional op mappings here. A None value indicates # that the op/register is deleted. self.op_map: dict[Value, Value | None] = {} def transform_blocks(self, blocks: list[BasicBlock]) -> None: """Transform basic blocks that represent a single function. The result of the transform will be collected at self.builder.blocks. """ block_map: dict[BasicBlock, BasicBlock] = {} op_map = self.op_map empties = set() for block in blocks: new_block = BasicBlock() block_map[block] = new_block self.builder.activate_block(new_block) new_block.error_handler = block.error_handler for op in block.ops: new_op = op.accept(self) if new_op is not op: op_map[op] = new_op # A transform can produce empty blocks which can be removed. if is_empty_block(new_block) and not is_empty_block(block): empties.add(new_block) self.builder.blocks = [block for block in self.builder.blocks if block not in empties] # Update all op/block references to point to the transformed ones. patcher = PatchVisitor(op_map, block_map) for block in self.builder.blocks: for op in block.ops: op.accept(patcher) if block.error_handler is not None: block.error_handler = block_map.get(block.error_handler, block.error_handler) def add(self, op: Op) -> Value: return self.builder.add(op) def visit_goto(self, op: Goto) -> None: self.add(op) def visit_branch(self, op: Branch) -> None: self.add(op) def visit_return(self, op: Return) -> None: self.add(op) def visit_unreachable(self, op: Unreachable) -> None: self.add(op) def visit_assign(self, op: Assign) -> Value | None: if op.src in self.op_map and self.op_map[op.src] is None: # Special case: allow removing register initialization assignments return None return self.add(op) def visit_assign_multi(self, op: AssignMulti) -> Value | None: return self.add(op) def visit_load_error_value(self, op: LoadErrorValue) -> Value | None: return self.add(op) def visit_load_literal(self, op: LoadLiteral) -> Value | None: return self.add(op) def visit_get_attr(self, op: GetAttr) -> Value | None: return self.add(op) def visit_set_attr(self, op: SetAttr) -> Value | None: return self.add(op) def visit_load_static(self, op: LoadStatic) -> Value | None: return self.add(op) def visit_init_static(self, op: InitStatic) -> Value | None: return self.add(op) def visit_tuple_get(self, op: TupleGet) -> Value | None: return self.add(op) def visit_tuple_set(self, op: TupleSet) -> Value | None: return self.add(op) def visit_inc_ref(self, op: IncRef) -> Value | None: return self.add(op) def visit_dec_ref(self, op: DecRef) -> Value | None: return self.add(op) def visit_call(self, op: Call) -> Value | None: return self.add(op) def visit_method_call(self, op: MethodCall) -> Value | None: return self.add(op) def visit_cast(self, op: Cast) -> Value | None: return self.add(op) def visit_box(self, op: Box) -> Value | None: return self.add(op) def visit_unbox(self, op: Unbox) -> Value | None: return self.add(op) def visit_raise_standard_error(self, op: RaiseStandardError) -> Value | None: return self.add(op) def visit_call_c(self, op: CallC) -> Value | None: return self.add(op) def visit_primitive_op(self, op: PrimitiveOp) -> Value | None: return self.add(op) def visit_truncate(self, op: Truncate) -> Value | None: return self.add(op) def visit_extend(self, op: Extend) -> Value | None: return self.add(op) def visit_load_global(self, op: LoadGlobal) -> Value | None: return self.add(op) def visit_int_op(self, op: IntOp) -> Value | None: return self.add(op) def visit_comparison_op(self, op: ComparisonOp) -> Value | None: return self.add(op) def visit_float_op(self, op: FloatOp) -> Value | None: return self.add(op) def visit_float_neg(self, op: FloatNeg) -> Value | None: return self.add(op) def visit_float_comparison_op(self, op: FloatComparisonOp) -> Value | None: return self.add(op) def visit_load_mem(self, op: LoadMem) -> Value | None: return self.add(op) def visit_set_mem(self, op: SetMem) -> Value | None: return self.add(op) def visit_get_element_ptr(self, op: GetElementPtr) -> Value | None: return self.add(op) def visit_set_element(self, op: SetElement) -> Value | None: return self.add(op) def visit_load_address(self, op: LoadAddress) -> Value | None: return self.add(op) def visit_keep_alive(self, op: KeepAlive) -> Value | None: return self.add(op) def visit_unborrow(self, op: Unborrow) -> Value | None: return self.add(op) class PatchVisitor(OpVisitor[None]): def __init__( self, op_map: dict[Value, Value | None], block_map: dict[BasicBlock, BasicBlock] ) -> None: self.op_map: Final = op_map self.block_map: Final = block_map def fix_op(self, op: Value) -> Value: new = self.op_map.get(op, op) assert new is not None, "use of removed op" return new def fix_block(self, block: BasicBlock) -> BasicBlock: return self.block_map.get(block, block) def visit_goto(self, op: Goto) -> None: op.label = self.fix_block(op.label) def visit_branch(self, op: Branch) -> None: op.value = self.fix_op(op.value) op.true = self.fix_block(op.true) op.false = self.fix_block(op.false) def visit_return(self, op: Return) -> None: op.value = self.fix_op(op.value) def visit_unreachable(self, op: Unreachable) -> None: pass def visit_assign(self, op: Assign) -> None: op.src = self.fix_op(op.src) def visit_assign_multi(self, op: AssignMulti) -> None: op.src = [self.fix_op(s) for s in op.src] def visit_load_error_value(self, op: LoadErrorValue) -> None: pass def visit_load_literal(self, op: LoadLiteral) -> None: pass def visit_get_attr(self, op: GetAttr) -> None: op.obj = self.fix_op(op.obj) def visit_set_attr(self, op: SetAttr) -> None: op.obj = self.fix_op(op.obj) op.src = self.fix_op(op.src) def visit_load_static(self, op: LoadStatic) -> None: pass def visit_init_static(self, op: InitStatic) -> None: op.value = self.fix_op(op.value) def visit_tuple_get(self, op: TupleGet) -> None: op.src = self.fix_op(op.src) def visit_tuple_set(self, op: TupleSet) -> None: op.items = [self.fix_op(item) for item in op.items] def visit_inc_ref(self, op: IncRef) -> None: op.src = self.fix_op(op.src) def visit_dec_ref(self, op: DecRef) -> None: op.src = self.fix_op(op.src) def visit_call(self, op: Call) -> None: op.args = [self.fix_op(arg) for arg in op.args] def visit_method_call(self, op: MethodCall) -> None: op.obj = self.fix_op(op.obj) op.args = [self.fix_op(arg) for arg in op.args] def visit_cast(self, op: Cast) -> None: op.src = self.fix_op(op.src) def visit_box(self, op: Box) -> None: op.src = self.fix_op(op.src) def visit_unbox(self, op: Unbox) -> None: op.src = self.fix_op(op.src) def visit_raise_standard_error(self, op: RaiseStandardError) -> None: if isinstance(op.value, Value): op.value = self.fix_op(op.value) def visit_call_c(self, op: CallC) -> None: op.args = [self.fix_op(arg) for arg in op.args] def visit_primitive_op(self, op: PrimitiveOp) -> None: op.args = [self.fix_op(arg) for arg in op.args] def visit_truncate(self, op: Truncate) -> None: op.src = self.fix_op(op.src) def visit_extend(self, op: Extend) -> None: op.src = self.fix_op(op.src) def visit_load_global(self, op: LoadGlobal) -> None: pass def visit_int_op(self, op: IntOp) -> None: op.lhs = self.fix_op(op.lhs) op.rhs = self.fix_op(op.rhs) def visit_comparison_op(self, op: ComparisonOp) -> None: op.lhs = self.fix_op(op.lhs) op.rhs = self.fix_op(op.rhs) def visit_float_op(self, op: FloatOp) -> None: op.lhs = self.fix_op(op.lhs) op.rhs = self.fix_op(op.rhs) def visit_float_neg(self, op: FloatNeg) -> None: op.src = self.fix_op(op.src) def visit_float_comparison_op(self, op: FloatComparisonOp) -> None: op.lhs = self.fix_op(op.lhs) op.rhs = self.fix_op(op.rhs) def visit_load_mem(self, op: LoadMem) -> None: op.src = self.fix_op(op.src) def visit_set_mem(self, op: SetMem) -> None: op.dest = self.fix_op(op.dest) op.src = self.fix_op(op.src) def visit_get_element_ptr(self, op: GetElementPtr) -> None: op.src = self.fix_op(op.src) def visit_set_element(self, op: SetElement) -> None: op.src = self.fix_op(op.src) def visit_load_address(self, op: LoadAddress) -> None: if isinstance(op.src, LoadStatic): new = self.fix_op(op.src) assert isinstance(new, LoadStatic), new op.src = new def visit_keep_alive(self, op: KeepAlive) -> None: op.src = [self.fix_op(s) for s in op.src] def visit_unborrow(self, op: Unborrow) -> None: op.src = self.fix_op(op.src) def is_empty_block(block: BasicBlock) -> bool: return len(block.ops) == 1 and isinstance(block.ops[0], Unreachable)