Created
October 2, 2024 00:30
-
-
Save raeeceip/6bae49c3d807bc2bc27e25c38a543573 to your computer and use it in GitHub Desktop.
Run with `python enhanced_inject_code.py path/to/your/file.py '{"function_name": [{"code": "print(\"Injection 1\")", "position": "start"}, {"code": "print(\"Injection 2\")", "position": "end"}]}' --dry-run` to test it, then try editing the JSON to inject code into the functions you want at a specific position.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import ast
import json
import os
import subprocess
import sys
from enum import Enum
from typing import Dict, List, Tuple

import astor
class InjectionPosition(Enum):
    """Places inside a function body where a snippet may be injected.

    The member values are the strings accepted in the ``position`` field
    of the injections JSON.
    """

    # At the beginning of the function body.
    START = 'start'
    # After the last statement of the function body.
    END = 'end'
    # Immediately ahead of a ``return`` statement.
    BEFORE_RETURN = 'before_return'
class InjectionVisitor(ast.NodeTransformer):
    """AST transformer that injects code snippets into named functions.

    ``injections`` maps a function name to a list of ``(code, position)``
    pairs. Every function (sync or async, at any nesting level) whose name
    is a key of that mapping receives the corresponding snippets.

    After a traversal, ``modified`` is True if at least one statement was
    actually inserted into the tree.
    """

    # Statements that open a new scope: a ``return`` inside them belongs to
    # a different function, so the BEFORE_RETURN search must not enter them.
    _SCOPE_NODES = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)

    def __init__(self, injections: Dict[str, List[Tuple[str, InjectionPosition]]]):
        self.injections = injections
        self.modified = False

    def visit_FunctionDef(self, node):
        """Apply the injections registered for ``node.name`` and return the node."""
        # Visit children first so nested functions get their own injections
        # and freshly injected statements are never traversed again.
        self.generic_visit(node)

        for code, position in self.injections.get(node.name, ()):
            if position == InjectionPosition.START:
                new_stmts = ast.parse(code).body
                # Keep a docstring as the first statement; otherwise the
                # function would silently lose it.
                has_docstring = ast.get_docstring(node, clean=False) is not None
                offset = 1 if has_docstring else 0
                node.body[offset:offset] = new_stmts
                inserted = bool(new_stmts)
            elif position == InjectionPosition.END:
                new_stmts = ast.parse(code).body
                node.body.extend(new_stmts)
                inserted = bool(new_stmts)
            elif position == InjectionPosition.BEFORE_RETURN:
                inserted = self._inject_before_returns(node.body, code)
            else:
                # Unknown position: nothing to do for this entry.
                inserted = False

            if inserted:
                self.modified = True
        return node

    # ``async def`` functions are separate node types but are handled alike.
    visit_AsyncFunctionDef = visit_FunctionDef

    def _inject_before_returns(self, stmts: List[ast.stmt], code: str) -> bool:
        """Insert ``code`` ahead of every ``return`` reachable from ``stmts``.

        The statement list is edited in place. Compound statements
        (``if``, loops, ``with``, ``try``, ``match``) are searched
        recursively; nested function and class definitions are skipped.
        Returns True if at least one statement was inserted.
        """
        inserted = False
        index = 0
        while index < len(stmts):
            stmt = stmts[index]
            if isinstance(stmt, ast.Return):
                # Parse afresh for each site so no node object is shared
                # between two places in the tree.
                new_stmts = ast.parse(code).body
                stmts[index:index] = new_stmts
                index += len(new_stmts)
                inserted = inserted or bool(new_stmts)
            elif not isinstance(stmt, self._SCOPE_NODES):
                for block in self._child_blocks(stmt):
                    if self._inject_before_returns(block, code):
                        inserted = True
            index += 1
        return inserted

    @staticmethod
    def _child_blocks(stmt: ast.stmt) -> List[List[ast.stmt]]:
        """Return the nested statement lists owned by a compound statement."""
        blocks = []
        for field in ('body', 'orelse', 'finalbody'):
            block = getattr(stmt, field, None)
            if isinstance(block, list):
                blocks.append(block)
        # ``except`` handlers and ``match`` cases each carry their own body.
        for holder in ('handlers', 'cases'):
            for child in getattr(stmt, holder, ()):
                blocks.append(child.body)
        return blocks
def inject_code(file_path: str, injections: Dict[str, List[Tuple[str, InjectionPosition]]], dry_run: bool = False) -> Tuple[bool, str]:
    """Inject code into the functions of a Python source file.

    Args:
        file_path: Path of the Python file to transform.
        injections: Mapping of function name to ``(code, position)`` pairs,
            as consumed by ``InjectionVisitor``.
        dry_run: When True, the file on disk is left untouched.

    Returns:
        ``(True, new_source)`` if the visitor reported a modification,
        otherwise ``(False, original_source)``.

    Raises:
        OSError: If the file cannot be read or written.
        SyntaxError: If the file (or an injected snippet) is not valid Python.

    Note:
        The new source is regenerated from the AST, so comments and the
        original formatting of the file are not preserved.
    """
    # Python source is UTF-8 by default; do not depend on the locale encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        source = file.read()

    # Passing the filename makes syntax-error messages point at the real file.
    tree = ast.parse(source, filename=file_path)
    visitor = InjectionVisitor(injections)
    modified_tree = visitor.visit(tree)

    if not visitor.modified:
        return False, source

    modified_source = astor.to_source(modified_tree)
    if not dry_run:
        with open(file_path, 'w', encoding='utf-8') as file:
            file.write(modified_source)
    return True, modified_source
def git_commit(file_path: str, message: str, dry_run: bool = False):
    """Stage ``file_path`` and commit it with ``message`` using git.

    Args:
        file_path: Path of the file to stage.
        message: Commit message.
        dry_run: When True, only print what would be executed.

    The commands are run without a shell and with explicit argument lists,
    so quotes, spaces or shell metacharacters in the path or the message
    cannot break or hijack the command line. Failures are reported on
    stderr instead of being raised; the commit is skipped if staging fails.
    """
    if dry_run:
        print(f"[Dry run] Would git add {file_path}")
        print(f"[Dry run] Would git commit -m '{message}'")
        return

    commands = (
        # ``--`` stops option parsing so a path starting with '-' is safe.
        ['git', 'add', '--', file_path],
        ['git', 'commit', '-m', message],
    )
    for command in commands:
        try:
            result = subprocess.run(command, check=False)
        except OSError as exc:
            # Typically: the git executable is not installed or not on PATH.
            print(f"Warning: could not run git: {exc}", file=sys.stderr)
            return
        if result.returncode != 0:
            print(
                f"Warning: 'git {command[1]}' exited with status {result.returncode}",
                file=sys.stderr,
            )
            return
def parse_injections(injections_json: str) -> Dict[str, List[Tuple[str, InjectionPosition]]]:
    """Parse the injections JSON string into ``(code, position)`` pairs.

    The expected shape is an object mapping each function name to a list
    of objects with a ``code`` string and an optional ``position``
    (defaults to ``'start'``).

    Args:
        injections_json: JSON text describing the injections.

    Returns:
        Mapping of function name to a list of ``(code, InjectionPosition)``.

    Raises:
        json.JSONDecodeError: If the text is not valid JSON.
        ValueError: If the JSON has the wrong shape or names an unknown
            position.
    """
    raw_injections = json.loads(injections_json)
    # Structural problems are raised as ValueError so callers can handle
    # every kind of bad input with a single except clause.
    if not isinstance(raw_injections, dict):
        raise ValueError("Injections must be a JSON object mapping function names to lists")

    parsed_injections = {}
    for func, injects in raw_injections.items():
        if not isinstance(injects, list):
            raise ValueError(f"Injections for '{func}' must be a list")
        entries = []
        for inject in injects:
            if not isinstance(inject, dict) or not isinstance(inject.get('code'), str):
                raise ValueError(
                    f"Each injection for '{func}' must be an object with a string 'code' field"
                )
            # An unknown position value makes the Enum lookup raise ValueError.
            position = InjectionPosition(inject.get('position', 'start'))
            entries.append((inject['code'], position))
        parsed_injections[func] = entries
    return parsed_injections
def main():
    """Command-line entry point: parse arguments, inject code, commit.

    Exits with status 1 when the injections argument is invalid or when
    the target file cannot be read, parsed or written.
    """
    parser = argparse.ArgumentParser(description="Inject code into Python files.")
    parser.add_argument('file_path', help="Path to the Python file to modify")
    parser.add_argument('injections', help="JSON string of injections")
    parser.add_argument('--dry-run', action='store_true', help="Perform a dry run without making changes")
    args = parser.parse_args()

    try:
        injections = parse_injections(args.injections)
    except json.JSONDecodeError:
        print("Error: Invalid JSON format for injections")
        sys.exit(1)
    except ValueError as e:
        print(f"Error: {str(e)}")
        sys.exit(1)

    try:
        modified, new_source = inject_code(args.file_path, injections, args.dry_run)
    except (OSError, SyntaxError, ValueError) as e:
        # Missing/unreadable file, or invalid Python in the target file or
        # in an injected snippet: report it instead of dumping a traceback.
        print(f"Error: {str(e)}")
        sys.exit(1)

    if not modified:
        print("No injections were made")
        return

    if args.dry_run:
        print("Dry run: The following changes would be made:")
        print(new_source)
    else:
        print(f"Code successfully injected into {args.file_path}")

    # Commit the changes (git_commit itself only prints during a dry run).
    relative_path = os.path.relpath(args.file_path)
    functions = ', '.join(injections.keys())
    commit_message = f"Inject code into functions: {functions} ({relative_path})"
    git_commit(relative_path, commit_message, args.dry_run)
    if args.dry_run:
        print(f"[Dry run] Would commit changes: {commit_message}")
    else:
        print(f"Changes committed: {commit_message}")
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment