#!/usr/bin/python3

"""filter-tarball - filter tarball allowing only known good licensing.

Usage: filter-tarball <orig-tarball>

Read the original tarball 'foo-1.2.3.tar.gz' and write a new tarball
'foo-1.2.3-filtered.tar.xz', containing only the files listed
with a license in the 'file-licensing' file next to this script.
Files whose license is given as 'x' are left out.

Example:
    filter-tarball avrdude-7.2.tar.gz

    Produces a filtered tarball version as 'avrdude-7.2-filtered.tar.xz'.

"""


import io
import sys
import tarfile
from pathlib import Path
from pprint import pprint


def load_license_map(license_file_path: Path) -> dict[str, str]:
    """Load the per-file license map from a tab-separated file.

    Each non-empty line not starting with '#' has the form
    ``license<TAB>filename[<TAB>remark]``.  Lines with a single field and
    entries with an empty license field are ignored; the optional remark
    column is ignored as well.  The parsed fields and the final map are
    dumped to stdout for inspection.

    Args:
        license_file_path: path of the license listing file.

    Returns:
        Mapping of file name (as looked up by run_filter, i.e. relative to
        the tarball's top-level directory) to license string.

    Raises:
        ValueError: if a line has more than three tab-separated fields.
    """
    license_map = {}

    with open(license_file_path, encoding="utf-8") as licf:
        for lineno, raw_line in enumerate(licf, start=1):
            # The last line may lack a trailing newline; strip only "\n" so
            # tabs (the field separator) and other whitespace survive.
            line = raw_line.removesuffix("\n")
            if not line or line.startswith("#"):
                continue
            items = line.split("\t")
            if len(items) == 1:
                continue
            if len(items) > 3:
                raise ValueError(
                    f"{license_file_path}:{lineno}: expected 2 or 3 "
                    f"tab-separated fields, got {len(items)}"
                )
            lic_str, fname = items[0], items[1]
            # items[2], if present, is a free-form remark and is ignored.
            pprint(items)
            if lic_str:
                license_map[fname] = lic_str
        print()

    pprint(license_map)
    print()
    return license_map


# Contents of the README.filtered-tarball file that add_readme_file() puts
# into the filtered tarball to explain how it differs from upstream.
README_TEXT = """\
About this filtered tarball
===========================

This tarball was created from the original upstream tarball for Fedora
by removing some files due to potential licensing issues.  Only files
listed with a known good license in the 'file-licensing' file were kept.

The original upstream tarball is available from

    https://github.com/avrdudes/avrdude
"""


def run_filter(
    orig_tarfile: tarfile.TarFile,
    filtered_tarfile: tarfile.TarFile,
    tarbase: str,
    license_map: dict[str, str],
):
    """Copy the licensed members of orig_tarfile into filtered_tarfile.

    Regular files are looked up in license_map by their path relative to
    the top-level ``tarbase/`` directory.  Files whose license is "x" are
    excluded and all others are copied.  A regular file missing from the
    map aborts the program with exit status 1.  Directory entries are
    copied, except those under ``tarbase/atmel-docs``, which are skipped.

    Raises:
        ValueError: if a regular file is not below ``tarbase/``, or a member
            is neither a regular file nor a directory (e.g. a symlink).
    """
    prefix = f"{tarbase}/"
    for finfo in orig_tarfile:
        # isreg() also accepts old-style AREGTYPE ("\0") and CONTTYPE
        # regular-file entries, which comparing against REGTYPE would reject.
        if finfo.isreg():
            if not finfo.name.startswith(prefix):
                raise ValueError(f"{finfo.name!r} is not below {prefix!r}")
            rel_fname = finfo.name.removeprefix(prefix)
            if rel_fname not in license_map:
                print(
                    "[ERR!]", finfo.name, "not in license map", file=sys.stderr
                )
                sys.exit(1)
            lic_str = license_map[rel_fname]
            if lic_str == "x":
                print("[EXCL]", finfo.name)
                continue
            # Close the member stream once its data has been copied.
            with orig_tarfile.extractfile(finfo) as fobj:
                filtered_tarfile.addfile(finfo, fobj)
            print("[COPY]", finfo.name, lic_str)
        elif finfo.isdir():
            if finfo.name.startswith(f"{tarbase}/atmel-docs"):
                print("[SKIP]", finfo.name)
            else:
                print(" [DIR]", finfo.name)
                filtered_tarfile.addfile(finfo)
        else:
            raise ValueError(f"Unhandled TarInfo type: {finfo}")


def add_readme_file(filtered_tarfile: tarfile.TarFile, tarbase: str):
    """Add README file to tarball describing filtered tarball."""
    readme_fname = f"{tarbase}/README.filtered-tarball"
    print(" [ADD]", readme_fname, "(remarks on filtered tarball)")
    readme_info = tarfile.TarInfo(readme_fname)
    readme_bytes = io.BytesIO(README_TEXT.encode("utf-8"))
    filtered_tarfile.addfile(readme_info, readme_bytes)


def main(argv=None):
    """Run main program.

    Args:
        argv: command line arguments without the program name; defaults
            to sys.argv[1:].  Exactly one argument is expected: the path
            of an existing '*.tar.gz' tarball.

    The tarball's name without '.tar.gz' is used as the name of the
    top-level directory inside it.  The result is written next to the
    original as '<base>-filtered.tar.xz'.  If anything fails while it is
    being written, the incomplete output file is removed and the exception
    is re-raised.

    Exits with an error message (via sys.exit) on bad arguments or if the
    output file already exists.
    """
    if argv is None:
        argv = sys.argv[1:]
    if len(argv) != 1:
        sys.exit("Usage: filter-tarball <orig-tarball>")
    orig_tarball = Path(argv[0])
    if not orig_tarball.is_file():
        sys.exit(f"filter-tarball: not a file: {orig_tarball}")
    if not orig_tarball.name.endswith(".tar.gz"):
        sys.exit(f"filter-tarball: not a '*.tar.gz' file: {orig_tarball}")
    tarbase = orig_tarball.name.removesuffix(".tar.gz")
    filtered_tarball = orig_tarball.parent / f"{tarbase}-filtered.tar.xz"
    # Checked before the try block so an existing file is never deleted.
    if filtered_tarball.exists():
        sys.exit(f"filter-tarball: refusing to overwrite: {filtered_tarball}")

    license_map = load_license_map(Path(__file__).parent / "file-licensing")

    try:
        with tarfile.open(orig_tarball) as orig_tarfile, tarfile.open(
            filtered_tarball, "w:xz"
        ) as filtered_tarfile:
            run_filter(orig_tarfile, filtered_tarfile, tarbase, license_map)
            add_readme_file(filtered_tarfile, tarbase)
    except BaseException:
        # BaseException also covers SystemExit from run_filter and
        # KeyboardInterrupt, so no truncated tarball is left behind.
        print("Removing incomplete tarball file:", filtered_tarball)
        filtered_tarball.unlink(missing_ok=True)
        raise


if __name__ == "__main__":
    # main() returns None on success, which sys.exit() maps to status 0.
    sys.exit(main())
