#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.9"
# dependencies = []
# ///


import os
import re
import sys

# Ordered (pattern, replacement) pairs that turn one line of async source into
# its sync twin. Each pattern is a regular expression (compiled below with
# ``re.compile``); unasync_line() applies them in list order, each one to the
# output of the previous, so order matters whenever two patterns overlap.
# Literal dots are escaped so they only match a real "." character.
SUBS = [
    # builtins
    ("async def", "def"),
    ("async with", "with"),
    ("await ", ""),
    ("async for", "for"),
    ("AsyncIterator", "Iterator"),
    ("AsyncIterable", "Iterable"),
    ("__aiter__", "__iter__"),
    ("AsyncMock", "MagicMock"),
    ("assert_awaited_once", "assert_called_once"),
    # Awaitable[T] -> T (only handles a bracketed argument without a nested
    # "]" before its own closing bracket, e.g. one level like Optional[int]).
    (r"Awaitable\[([^\]]+)\]", r"\1"),
    # our public API
    ("AsyncCacheProxy", "SyncCacheProxy"),
    ("AsyncBaseStorage", "SyncBaseStorage"),
    ("AsyncCacheClient", "SyncCacheClient"),
    ("AsyncSqliteStorage", "SyncSqliteStorage"),
    ("anysqlite", "sqlite3"),
    ("aiter_stream", "iter_stream"),
    ("aiter_raw", "iter_raw"),
    ("aprint_sqlite_state", "print_sqlite_state"),
    ("make_async_iterator", "make_sync_iterator"),
    ("AsyncCacheTransport", "SyncCacheTransport"),
    (
        r"hishel\._core\._storages\._async_base",
        "hishel._core._storages._sync_base",
    ),
    # Third-party libraries
    ("AsyncClient", "Client"),
    (r"@pytest\.mark\.anyio", ""),
    ("aread", "read"),
    ("aclose", "close"),
    ("handle_async_request", "handle_request"),
    ("AsyncBaseTransport", "BaseTransport"),
    ("AsyncHTTPTransport", "HTTPTransport"),
]
# Pre-compiled form of SUBS; index i here corresponds to SUBS[i].
COMPILED_SUBS = [(re.compile(regex), repl) for regex, repl in SUBS]

# Indices (into SUBS) of substitutions that have changed at least one line.
# main() reports every SUBS entry whose index never lands here.
USED_SUBS: set = set()


def unasync_line(line):
    """Return *line* with every substitution in COMPILED_SUBS applied.

    The substitutions run in order, each on the result of the previous one.
    Whenever a substitution actually alters the text, its index is recorded
    in the module-level USED_SUBS set.
    """
    for position, (pattern, replacement) in enumerate(COMPILED_SUBS):
        rewritten = pattern.sub(replacement, line)
        if rewritten != line:
            # set.add is idempotent, so re-adding a known index is harmless.
            USED_SUBS.add(position)
        line = rewritten
    return line


def unasync_file(in_path, out_path):
    """Write the sync translation of the async file *in_path* to *out_path*.

    Every line is passed through unasync_line(). The input is read in text
    mode with universal newlines, so CRLF/CR endings become LF. The output is
    opened with ``newline=""``, so lines are written without further
    translation.

    Both files are handled as UTF-8 (the default encoding of Python source,
    PEP 3120) instead of the locale's preferred encoding, which differs on
    some platforms such as Windows.

    The whole input is read and translated before *out_path* is opened, so a
    read/decode error leaves an existing output file untouched rather than
    truncated.
    """
    with open(in_path, encoding="utf-8") as in_file:
        translated = [unasync_line(line) for line in in_file]
    with open(out_path, "w", encoding="utf-8", newline="") as out_file:
        out_file.writelines(translated)


def unasync_file_check(in_path, out_path):
    """Verify that *out_path* is exactly the sync translation of *in_path*.

    Each line of the async file is passed through unasync_line() and compared
    with the corresponding line of the sync file. On the first difference a
    diagnostic is printed and the process exits with status 1.

    The two files are walked with ``zip_longest`` rather than ``zip``, so a
    sync file that is shorter or longer than its async source is reported as a
    mismatch instead of being silently accepted. The missing side appears as
    ``None`` in the diagnostic.
    """
    from itertools import zip_longest

    with open(in_path, encoding="utf-8") as in_file:
        with open(out_path, encoding="utf-8") as out_file:
            for in_line, out_line in zip_longest(in_file, out_file):
                expected = None if in_line is None else unasync_line(in_line)
                if out_line != expected:
                    print(f"unasync mismatch between {in_path!r} and {out_path!r}")
                    print(f"Async code:         {in_line!r}")
                    print(f"Expected sync code: {expected!r}")
                    print(f"Actual sync code:   {out_line!r}")
                    sys.exit(1)


def unasync_dir(in_dir, out_dir, check_only=False):
    """Translate (or, with *check_only*, verify) every ``.py`` file under *in_dir*.

    The layout below *in_dir* is mirrored under *out_dir*: a file at
    ``in_dir/<rel>/name.py`` maps to ``out_dir/<rel>/name.py``. Each pair is
    printed before it is processed.

    Args:
        in_dir: Root of the async source tree.
        out_dir: Root of the generated sync tree.
        check_only: If true, compare each pair with unasync_file_check()
            (which exits with status 1 on the first mismatch) instead of
            writing files.
    """
    for dirpath, dirnames, filenames in os.walk(in_dir):
        # os.walk's listing order depends on the filesystem. Sorting dirnames
        # in place fixes the order of descent, and sorting filenames fixes the
        # order within a directory, so output (and the first mismatch
        # reported) is deterministic.
        dirnames.sort()
        rel_dir = os.path.relpath(dirpath, in_dir)
        for filename in sorted(filenames):
            if not filename.endswith(".py"):
                continue
            in_path = os.path.normpath(os.path.join(in_dir, rel_dir, filename))
            out_path = os.path.normpath(os.path.join(out_dir, rel_dir, filename))
            print(in_path, "->", out_path)
            if check_only:
                unasync_file_check(in_path, out_path)
            else:
                # open(..., "w") does not create missing parent directories;
                # dirname() is "" for a bare filename, hence the "." fallback.
                os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
                unasync_file(in_path, out_path)


def main():
    """Regenerate the sync modules from their async sources, or verify them.

    With ``--check`` among the command-line arguments, each sync file is
    compared against a fresh translation; unasync_file_check() exits with
    status 1 on the first mismatch. Otherwise every sync file is rewritten.
    Afterwards, any SUBS entry that never changed a line is listed.
    """
    check_only = "--check" in sys.argv

    # (async source, generated sync target) pairs. The paths are relative, so
    # they resolve against the current working directory (presumably the
    # repository root; confirm against how this script is invoked).
    file_pairs = [
        (
            "tests/_core/_async/test_sqlite_storage.py",
            "tests/_core/_sync/test_sqlite_storage.py",
        ),
        ("tests/test_async_httpx.py", "tests/test_sync_httpx.py"),
        ("hishel/_async_cache.py", "hishel/_sync_cache.py"),
        (
            "hishel/_core/_storages/_async_base.py",
            "hishel/_core/_storages/_sync_base.py",
        ),
        (
            "hishel/_core/_storages/_async_sqlite.py",
            "hishel/_core/_storages/_sync_sqlite.py",
        ),
        ("hishel/_async_httpx.py", "hishel/_sync_httpx.py"),
    ]

    for source, target in file_pairs:
        if not check_only:
            unasync_file(source, target)
            print(f"Wrote {target}")
        else:
            unasync_file_check(source, target)

    # USED_SUBS only ever holds indices into SUBS, so this list is non-empty
    # exactly when some substitution was never exercised.
    unused_subs = [sub for index, sub in enumerate(SUBS) if index not in USED_SUBS]
    if unused_subs:
        from pprint import pprint

        print("This SUBS was not used")
        pprint(unused_subs)


# Run main() only when this file is executed directly (e.g. through the
# ``uv run --script`` shebang on the first line), not when it is imported.
if __name__ == "__main__":
    main()
