#!/opt/starfish/examples/venv/bin/python3
"""
***********************************************************************************************************

 Starfish Storage Corporation ("Starfish") CONFIDENTIAL
 Unpublished Copyright (c) 2011 - present Starfish Storage Corporation, All Rights Reserved.

 NOTICE: This file and its contents (1) constitute Starfish's "External Code" under Starfish's most-recent
 Limited Software End-User License Agreement, and (2) is and remains the property of Starfish. The
 intellectual and technical concepts contained herein are proprietary to Starfish and may be covered by
 U.S. and/or foreign patents or patents in process, and are protected by trade secret or copyright law.
 Dissemination of this information or reproduction of this material is strictly forbidden unless prior
 written permission is obtained from Starfish. Access to the source code contained herein is hereby
 forbidden to anyone except (A) current Starfish employees, managers, or contractors who have executed
 confidentiality or nondisclosure agreements explicitly covering such access, and (B) licensees of
 Starfish's software.

 ANY REPRODUCTION, COPYING, MODIFICATION, DISTRIBUTION, PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR
 THROUGH USE OF THIS SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF STARFISH IS STRICTLY PROHIBITED
 AND IS IN VIOLATION OF APPLICABLE LAWS AND INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS
 FILE OR ITS CONTENTS AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS TO REPRODUCE,
 DISCLOSE, OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN
 WHOLE OR IN PART.

 FOR U.S. GOVERNMENT CUSTOMERS REGARDING THIS DOCUMENTATION/SOFTWARE
   These notices shall be marked on any reproduction of this data, in whole or in part.
   NOTICE: Notwithstanding any other lease or license that may pertain to, or accompany the delivery of,
   this computer software, the rights of the Government regarding its use, reproduction and disclosure are
   as set forth in Section 52.227-19 of the FARS Computer Software-Restricted Rights clause.
   RESTRICTED RIGHTS NOTICE: Use, duplication, or disclosure by the Government is subject to the
   restrictions as set forth in subparagraph (c)(1)(ii) of the Rights in Technical Data and Computer
   Software clause at DFARS 52.227-7013.

***********************************************************************************************************
"""

###############################################################################
#  Author Doug Hughes
#  Last modified 2023-08-17
#
# Compare directory trees using SQL wrapped with a little Python
#
# WARNING: This script runs queries directly against the Starfish database
# and may use up memory or become outdated if there is a database schema change
#
# 2019-11-05: change the sub-expression to glob on directory% instead of directory/% to find file matches at top level
# 2019-11-06: add --name-only option and eliminate leading / from paths
# 2020-09-28: add --compare-sha1
# 2020-12-08: normalize by removing leading / from any volpath
# 2021-04-03: adjust treediff query so that it's blazing fast
# 2021-04-05: add itersize to prevent output from using up all memory when getting big query result
# 2021-04-06: use right instead of replace
# 2021-04-09: add case-insensitive and common modes for matches
# 2021-04-12: add lots of unit tests for case insensitive and name only and intersection modes
# 2021-04-23: fix tree top of branch issues with comparisons (special case and slash handling) - more/fixed self tests
# 2021-07-27: fix comparisons and test cases for top of volume comparisons
# 2022-08-18: Refactor and fix pylint warning: 'redefined-outer-name'
# 2023-06-02: make the sql 500x faster by eliminating the CTE and directly using volume_id and ancestor
# 2023-08-17: add check for missing or misspelled volume name

import argparse
import configparser
import os
import re
import sys
import time
import unittest

import psycopg2


class TestQ(unittest.TestCase):
    """Integration tests for the tree comparison queries.

    The tests talk to the database returned by getpgauth() and compare the
    volumes named ``s1`` and ``s2``, which make_vols() creates and scans.
    """

    @staticmethod
    def make_vols():
        """create test volumes to run unit tests"""
        # create test volume 1
        trymkdir("/tmp/s1")
        trymkdir("/tmp/s1/doug")
        trymkdir("/tmp/s1/doug/dir1")
        trymkdir("/tmp/s1/doug/dir1/dir2")  # vol1 only
        trymkdir("/tmp/s1/bob")
        trymkdir("/tmp/s1/ben")
        with open("/tmp/s1/doug/file1", "w") as f:
            print("hello", file=f)
        with open("/tmp/s1/doug/file2", "w") as f:
            print("hello", file=f)  # different size
        with open("/tmp/s1/doug/file3", "w") as f:
            print("hello", file=f)  # different contents same size
        with open("/tmp/s1/doug/file4", "w") as f:
            print("hello", file=f)  # case insensitive file match

        # create test volume 2
        trymkdir("/tmp/s2")
        trymkdir("/tmp/s2/doug")
        trymkdir("/tmp/s2/doug/dir1")
        trymkdir("/tmp/s2/doug/dir1/dir3")  # vol2 only
        trymkdir("/tmp/s2/bob")  # same dir
        trymkdir("/tmp/s2/Ben")  # case insensitive test difference
        with open("/tmp/s2/doug/file1", "w") as f:
            print("hello", file=f)
        with open("/tmp/s2/doug/file2", "w") as f:
            print("hello2", file=f)  # different size
        with open("/tmp/s2/doug/file3", "w") as f:
            print("hella", file=f)  # different contents same size
        with open("/tmp/s2/doug/File4", "w") as f:
            print("hello", file=f)  # case insensitive file match

        # do these the trivial, naive way for expediency
        os.system("sf volume add s1 /tmp/s1 > /dev/null 2>&1")
        os.system("sf volume add s2 /tmp/s2 > /dev/null 2>&1")
        time.sleep(5)
        os.system("sf scan start --wait s1 >/dev/null 2>&1")
        os.system("sf scan start --wait s2 >/dev/null 2>&1")

    def test_auth(self):
        """test connection"""
        conn = psycopg2.connect(getpgauth())
        # release the connection even when an assertion fails
        self.addCleanup(conn.close)
        self.assertIsNotNone(conn)

    def test_query(self):
        """test sql return"""
        conn = psycopg2.connect(getpgauth())
        self.addCleanup(conn.close)
        query = """select count(*) from sf_volumes.volume"""
        cur = conn.cursor()
        cur.execute(query)
        self.assertIsNotNone(cur)
        rows = cur.fetchall()
        self.assertEqual(len(rows), 1)

    def test_volpath_validate(self):
        """normalizing volume path"""
        self.assertCountEqual(validate_volpath("home:this"), ["home", "this"])
        self.assertCountEqual(validate_volpath("home:this/"), ["home", "this"])
        self.assertCountEqual(validate_volpath("home:/this/"), ["home", "this"])
        self.assertCountEqual(validate_volpath("home:/this//"), ["home", "this"])
        self.assertCountEqual(validate_volpath("home:/this///"), ["home", "this"])
        self.assertCountEqual(validate_volpath("home:/this/that"), ["home", "this/that"])
        with self.assertRaises(SystemExit):
            validate_volpath("home")

    def check_stdout(self, cliargs, query, expected):
        """Run fetch_rows() with stdout captured and compare the printed lines.

        The comparison ignores line order. A SystemExit raised by fetch_rows()
        (it exits when the query returns no rows) propagates to the caller
        after stdout has been restored.
        """
        from io import StringIO  # pylint: disable=C0415

        saved_stdout = sys.stdout
        out = StringIO()
        sys.stdout = out
        try:
            fetch_rows(cliargs, query)
        finally:
            sys.stdout = saved_stdout

        output = out.getvalue().rstrip()
        if cliargs.debug:
            # printed after stdout is restored so the text is actually visible
            print(f"output = {output}")
            print(f"expected = {expected}")
        self.assertCountEqual(output.split("\n"), expected.split("\n"))

    @staticmethod
    def build_cliargs(**kwargs):
        """Build the cliargs since it's very repetitive. Pass in the thing to change.

        Returns a fresh argparse.Namespace instance on every call so that one
        test can never leak option values into another.
        """
        options = {
            "left_only": False,
            "compare_sha1": False,
            "name_only": False,
            "right_only": False,
            "intersection": False,
            "insensitive": False,
            "nowarn": True,
            "quiet": True,
            "csv": False,
            "debug": 0,
            "delimiter": " ",
            "VOLPATH1": "s1:",
            "VOLPATH2": "s2:",
        }
        options.update(kwargs)
        return argparse.Namespace(**options)

    # these cover all the major test cases

    def test_left_only(self):
        """test compare looking at s1 differences only"""
        cliargs = self.build_cliargs(left_only=True)
        q = build_query(cliargs)
        self.check_stdout(cliargs, q, "ben 1024\ndoug/dir1/dir2 1024\ndoug/file2 6\ndoug/file4 6")

    def test_left_only_name(self):
        """test compare looking at s1 differences only ignore size"""
        cliargs = self.build_cliargs(left_only=True, name_only=True)
        q = build_query(cliargs)
        self.check_stdout(cliargs, q, "ben\ndoug/dir1/dir2\ndoug/file4")

    def test_right_only(self):
        """test compare looking at s2 differences"""
        cliargs = self.build_cliargs(right_only=True)
        q = build_query(cliargs)
        self.check_stdout(cliargs, q, " Ben 1024\n doug/dir1/dir3 1024\n doug/file2 7\n doug/File4 6")

    def test_right_only_name(self):
        """test compare looking at s2 differences only ignore size"""
        cliargs = self.build_cliargs(right_only=True, name_only=True)
        q = build_query(cliargs)
        self.check_stdout(cliargs, q, " Ben\n doug/dir1/dir3\n doug/File4")

    def test_intersection_name_only(self):
        """test intersection (same file exists on both sides) with name only, case sensitive"""
        cliargs = self.build_cliargs(intersection=True, name_only=True)
        q = build_query(cliargs)
        self.check_stdout(cliargs, q, "doug/file1\ndoug/file2\ndoug/file3\ndoug\nbob\ndoug/dir1")

    def test_intersection_name_only_ic(self):
        """test intersection (same file exists on both sides) with name only, case insensitive"""
        cliargs = self.build_cliargs(intersection=True, name_only=True, insensitive=True)
        q = build_query(cliargs)
        self.check_stdout(cliargs, q, "doug/file4\ndoug/file1\ndoug/file2\ndoug/file3\ndoug\nbob\nben\ndoug/dir1")

    def test_intersection_ic(self):
        """test intersection (same file exists on both sides) with sizes, case insensitive"""
        cliargs = self.build_cliargs(intersection=True, insensitive=True)
        q = build_query(cliargs)
        expected_output = (
            "ben 1024 ben 1024\nbob 1024 bob 1024\ndoug 1024 doug 1024\ndoug/dir1 1024 doug/dir1 1024\n"
            "doug/file1 6 doug/file1 6\ndoug/file3 6 doug/file3 6\ndoug/file4 6 doug/file4 6"
        )
        self.check_stdout(cliargs, q, expected_output)

    def test_identical(self):
        """test 2 identical things. No errors."""
        cliargs = self.build_cliargs(intersection=False, insensitive=False, VOLPATH2="s1:")
        q = build_query(cliargs)
        expected_output = ""
        # fetch_rows() exits with status 0 when there is nothing to report
        with self.assertRaises(SystemExit) as cm:
            self.check_stdout(cliargs, q, expected_output)
        self.assertEqual(cm.exception.code, 0)

##################################################################################
# Global functions
##################################################################################


def trymkdir(path):
    """Best-effort recursive directory creation used while building test volumes.

    Any failure is ignored on purpose: the usual cause is that the directory
    is already there from an earlier run.
    """
    try:
        os.makedirs(path)
    except Exception:  # pylint: disable=W0703
        # only used by the test fixtures, so a failure here is not interesting
        return


def validate_volpath(volpath):
    """Split a ``VOLUME:PATH`` string into a normalized ``[volume, path]`` list.

    The text before the first colon is the volume name (trailing slashes
    removed); everything after it is the path, with all leading and trailing
    slashes removed so that ``vol:dir``, ``vol:/dir/`` and ``vol://dir//``
    are equivalent. An empty path (``vol:``) is returned as ``""``.

    :param volpath: string in the form VOLUME:PATH
    :returns: two element list ``[volume, path]``
    :raises SystemExit: with status 1 when the string contains no colon
    """
    volume, colon, path = volpath.partition(":")
    if not colon:
        print("Volpath must be in the form of VOLUME:PATH", file=sys.stderr)
        sys.exit(1)
    return [volume.rstrip("/"), path.strip("/")]


def getpgauth(config_path="/opt/starfish/etc/99-local.ini"):
    """Return the database connection URI stored in the Starfish config file.

    The value is read from option ``pg_uri`` in section ``[pg]``.

    :param config_path: ini file to read; defaults to the Starfish local config
    :returns: the ``pg_uri`` string
    :raises SystemExit: with status 1 when the file is not readable, cannot be
        parsed, or does not contain ``[pg] pg_uri``
    """
    if not os.access(config_path, os.R_OK):
        print(
            f"No access to {config_path}. You may need to "
            "run this as Starfish user, as root, or add group read permissions to this file.",
            file=sys.stderr,
        )
        sys.exit(1)

    config = configparser.ConfigParser()
    try:
        # ConfigParser.read() skips files it cannot open and returns the list
        # of files it did parse, so an empty result means the read failed.
        if config.read(config_path):
            return config.get("pg", "pg_uri")
    except configparser.Error as exc:
        print(f"can't get pg_uri from section [pg] of {config_path}: {exc}", file=sys.stderr)
        sys.exit(1)

    print("can't read config file to get connection uri. check permissions.", file=sys.stderr)
    sys.exit(1)


def get_dir_id(volid, path):
    """Look up the id of directory ``path`` on volume ``volid``.

    Resolving the id up front lets the main query filter on ancestor ids
    instead of using an inefficient CTE.

    :param volid: volume id as returned by get_volume_id()
    :param path: directory path relative to the volume root ("" for the root)
    :returns: the id column of the matching sf.dir_current row
    :raises SystemExit: with status 1 when no such directory exists
    """
    # bound parameters: the path comes from the command line and may contain quotes
    query = "select id from sf.dir_current where volume_id = %s AND path = %s"

    rows, count = runq(query, (volid, path))
    if count == 0:
        print(f"path {path} does not exist. Check spelling.", file=sys.stderr)
        sys.exit(1)
    return rows[0][0]


def get_volume_id(vname):
    """Look up the id of the volume called ``vname`` to optimize partition usage.

    :param vname: volume name as given on the command line
    :returns: the id column of the matching sf_volumes.volume row
    :raises SystemExit: with status 1 when no such volume exists
    """
    # bound parameter: the name comes from the command line and may contain quotes
    query = "select id from sf_volumes.volume where name = %s"

    rows, count = runq(query, (vname,))
    if count == 0:
        print(f"volume '{vname}' does not exist. Check spelling.", file=sys.stderr)
        sys.exit(1)
    return rows[0][0]


def runq(query, params=None):
    """Open a connection, execute ``query`` and return all of its rows.

    The connection is closed again before returning, so the complete result
    set is held in client memory.

    :param query: SQL text; may contain ``%s`` placeholders when ``params`` is given
    :param params: optional sequence of values bound to the placeholders
    :returns: tuple ``(rows, count)``; when the query yields no rows the
        placeholder ``([[0]], 0)`` is returned so callers can still index
        ``rows[0][0]``
    :raises SystemExit: with status 1 when the database connection fails
    """
    try:
        conn = psycopg2.connect(getpgauth())
    except psycopg2.DatabaseError as e:
        print(f"unable to connect to the database: {e}", file=sys.stderr)
        sys.exit(1)

    # The module-level ``args`` only exists once main() has parsed the command
    # line; fall back to "no debugging" when this is called without it.
    debug = getattr(globals().get("args"), "debug", False)

    try:
        with conn.cursor() as cur:
            if debug:
                print("executing query " + query, file=sys.stderr)
                if params is not None:
                    print(f"with parameters {params!r}", file=sys.stderr)

            cur.execute(query, params)
            rows = cur.fetchall()
            count = cur.rowcount
    finally:
        conn.close()

    if count == 0:
        print("no rows returned.", file=sys.stderr)
        return [[0]], 0

    return rows, count


def fetch_rows(args, query):  # pylint: disable=W0621
    """Run ``query`` and print its rows in a ``comm``-like layout.

    Rows that exist only on the left are printed starting in the first column,
    rows that exist only on the right are prefixed with the delimiter, and rows
    present on both sides are printed with all columns joined by the delimiter.
    The delimiter is a comma with ``--csv``, otherwise ``args.delimiter`` (a
    space when that is empty).

    :param args: parsed options; reads csv, delimiter, quiet, name_only and
        intersection
    :param query: SQL text produced by build_query()
    :raises SystemExit: with status 0 when the query returns no rows
    """
    if args.csv:
        delimiter = ","
    elif args.delimiter:
        delimiter = args.delimiter
    else:
        delimiter = " "

    rows, rowcount = runq(query)

    if not args.quiet:
        print("found %d" % rowcount, file=sys.stderr)  # noqa: S001

    if rowcount == 0:
        # shortcut
        sys.exit(0)

    for row in rows:
        if args.name_only:
            # columns: p1, p2
            if row[1] is None or args.intersection:
                # left-only entry, or an intersection where both names match:
                # just print the one column
                print(row[0])
            elif row[0] is None:
                print(delimiter + row[1])
            else:
                print(delimiter.join(str(el) for el in row))
        else:
            # columns: p1, size1, p2, size2
            if row[0] is None:
                print(delimiter + row[2] + delimiter + str(row[3]))
            elif row[2] is None:
                print(row[0] + delimiter + str(row[1]))
            else:
                print(delimiter.join(str(el) for el in row))


def _relative_path_sql(path):
    """Return the SQL expression that yields d.path relative to ``path``."""
    if path == "":
        # case to handle when we are comparing at top of a volume, using right(path,0)
        # suppresses the entire thing, which causes issues
        return "TRIM(LEADING '/' FROM d.path)"
    # Trim the starting directory off the front of d.path so both trees are
    # compared relative to their own starting point. Single quotes are doubled
    # so a path such as "bob's" stays inside the SQL string literal
    # (assumes standard_conforming_strings is on - TODO confirm for this server).
    escaped = path.replace("'", "''")
    return f"right(d.path, -length('{escaped}'))"


def build_query(args):  # pylint: disable=R0914,W0621
    """Build the SQL script that compares the two directory trees.

    The script snapshots each tree (files plus directories) into the temporary
    tables ``dt1`` (left, column ``p1``) and ``dt2`` (right, column ``p2``),
    analyzes them, and finishes with a join of the two tables whose type and
    filter depend on the selected mode.

    :param args: parsed options; reads compare_sha1, name_only, nowarn,
        VOLPATH1, VOLPATH2, insensitive, intersection, left_only, right_only
    :returns: a SQL query
    :raises SystemExit: with status 1 when --compare-sha1 and --name-only are
        combined, or (via the lookup helpers) when a volume or path is unknown
    """

    if args.compare_sha1 and args.name_only:
        print("--compare-sha1 and --name-only are mutually exclusive", file=sys.stderr)
        sys.exit(1)

    # modify the query to also compare sha1
    if args.compare_sha1:
        sha1_join = "LEFT JOIN sf.job_result_current j ON f.volume_id = j.volume_id AND f.id = j.fs_entry_id"
        sha1_where = "AND j.result ? 'sha1'"
        sha1_main_compare = "AND dt1.hashval = dt2.hashval"
        sha1_select = ", CAST(j.result->>'sha1' AS varchar) as hashval"
        # so that the directory query sha1 has a value in the union
        # that is guaranteed to be the same
        sha1_stub = ", 'aaaaa'"
    else:
        sha1_join = ""
        sha1_where = ""
        sha1_main_compare = ""
        sha1_select = ""
        sha1_stub = ""

    if not args.nowarn:
        print(
            "WARNING: This may use up a large amount of memory in Postgres if run\n"
            "against very large directory trees. \n"
            "Disable this warning with --nowarn",
            file=sys.stderr,
        )

    srcvol, srcpath = validate_volpath(args.VOLPATH1)
    dstvol, dstpath = validate_volpath(args.VOLPATH2)

    srcvolid = get_volume_id(srcvol)
    dstvolid = get_volume_id(dstvol)
    srcpathid = get_dir_id(srcvolid, srcpath)
    dstpathid = get_dir_id(dstvolid, dstpath)

    # in case insensitive mode we patch the query string to force paths to lower for comparison in the join
    if args.insensitive:
        ci_begin, ci_end = "lower(", ")"
    else:
        ci_begin, ci_end = "", ""

    if args.intersection:
        join_type = "INNER"
        # suppress useless comparison of top level
        row_filter = "WHERE p1 != ''"
    else:
        row_filter = "WHERE p1 IS NULL OR p2 IS NULL"
        # Set outer join type
        if args.left_only:
            join_type = "LEFT OUTER"
        elif args.right_only:
            join_type = "RIGHT OUTER"
        else:
            join_type = "FULL OUTER"

    # name only is good to check if things exist at all on one side that aren't on the other (exclusive)
    # dt1 = left dir tree (1st position); dt2 = right
    if not args.name_only:
        file_size_col = ", f.size"
        # hard code a fake place holder value since a dir size may be strongly divergent between any 2 trees
        dir_size_stub = ", 1024"
        main_select = f"""SELECT dt1.p1, dt1.size, dt2.p2, dt2.size
                         FROM dt1 {join_type} JOIN dt2 ON p1 = p2
                         AND dt1.size = dt2.size
                         {sha1_main_compare}
                         {row_filter}
                      """  # noqa: S001
    else:
        file_size_col = ""
        dir_size_stub = ""
        main_select = f"""SELECT dt1.p1, dt2.p2
                         FROM dt1 {join_type} JOIN dt2 ON p1 = p2
                         {row_filter}
                      """  # noqa: S001

    def tree_table(table, column, path_expr, volid, pathid):
        """SQL that fills temporary ``table`` with every entry below ``pathid``."""
        return f"""
        SELECT {ci_begin} TRIM(LEADING '/' FROM {path_expr} || '/' || f.name) {ci_end} AS {column}{file_size_col}{sha1_select}
          INTO TEMPORARY TABLE {table}
            FROM sf.file_current f
            INNER JOIN sf.dir_current d ON d.id = f.parent_id AND d.volume_id = f.volume_id
            {sha1_join}
            WHERE d.volume_id = {volid} AND d.ancestor_ids && array[{pathid}::BIGINT]
            {sha1_where}
          UNION ALL
          -- empty dirs
          SELECT {ci_begin} TRIM(LEADING '/' FROM {path_expr}) {ci_end} AS {column}{dir_size_stub}{sha1_stub}
          FROM sf.dir_current d
          WHERE d.ancestor_ids && array[{pathid}::BIGINT]
                AND d.volume_id = {volid}
        ;
        analyze {table};
        """  # noqa: E501

    # selectively merge in joins and wheres
    left_table = tree_table("dt1", "p1", _relative_path_sql(srcpath), srcvolid, srcpathid)
    right_table = tree_table("dt2", "p2", _relative_path_sql(dstpath), dstvolid, dstpathid)

    return f"{left_table}{right_table}        {main_select}\n"


# ################## main ##########################


def main():
    """Command line entry point.

    Parses the options, stores them in the module-level ``args`` (runq() reads
    the debug flag from there), then either builds the test volumes, runs the
    unit tests, or builds and executes the comparison query and prints the
    result. Always finishes through sys.exit().
    """
    # Parse Arguments
    global args  # pylint: disable=W0601
    descr = (
        "this is similar to /opt/starfish/bin/cmp_file_trees "
        "but uses all Sql to do the comparison and prints out "
        "the results without a summary and in a traditional "
        "<comm> like format."
    )
    parser = argparse.ArgumentParser(description=descr)
    parser.add_argument("--csv", action="store_true", help="use comma separated values output")
    parser.add_argument("--delimiter", default="\t", help="use a delimiter of your choice in csv output")
    parser.add_argument("--debug", action="store_true", required=False, help="add some debugging to output")
    parser.add_argument("--verbose", action="store_true", required=False, help="print verbose status messages")
    parser.add_argument(
        "--nowarn", action="store_true", required=False, help="suppress warning message about memory use"
    )
    parser.add_argument(
        "--compare-sha1",
        action="store_true",
        required=False,
        help="compare the sha1 on source and dest in addition to size "
        "(requires running hash job on source and test)",
    )
    parser.add_argument(
        "--insensitive", "-i", action="store_true", required=False, help="do case insensitive name comparison"
    )
    parser.add_argument(
        "--intersection",
        action="store_true",
        required=False,
        help="instead of showing differences, show all the things in common",
    )
    parser.add_argument(
        "--quiet", action="store_true", required=False, help="suppress some output like number of found rows"
    )
    parser.add_argument("--make-test-vols", action="store_true", required=False, help=argparse.SUPPRESS)
    parser.add_argument("--test", action="store_true", required=False, help=argparse.SUPPRESS)
    parser.add_argument(
        "--left-only", "-l", required=False, action="store_true", help="output left-hand differences only"
    )
    parser.add_argument(
        "--right-only", "-r", required=False, action="store_true", help="output right-hand differences only"
    )
    parser.add_argument(
        "--name-only",
        "-n",
        required=False,
        default=False,
        action="store_true",
        help="use only names and not size for right and left comparison",
    )
    # The positionals are optional at parse time so the hidden --make-test-vols
    # and --test modes can run without dummy volume paths; a normal run still
    # requires both (checked below).
    parser.add_argument("VOLPATH1", nargs="?", help="left volume and path to compare")
    parser.add_argument("VOLPATH2", nargs="?", help="right volume and path to compare")
    args = parser.parse_args()

    if args.make_test_vols:
        TestQ.make_vols()
        print("test volumes s1 and s2 created")
        sys.exit(0)
    if args.test:
        if not os.path.exists("/tmp/s1") or not os.path.exists("/tmp/s2"):
            print(f"run with {sys.argv[0]} --make-test-vols first", file=sys.stderr)
            sys.exit(1)
        # argv is replaced so unittest does not try to interpret our own options
        unittest.main(argv=["first-arg-is-ignored"], exit=True)

    if args.VOLPATH1 is None or args.VOLPATH2 is None:
        parser.error("the following arguments are required: VOLPATH1, VOLPATH2")

    query = build_query(args)
    if args.verbose:
        print("fetching differences")
    fetch_rows(args, query)
    sys.exit(0)


if __name__ == "__main__":
    main()
