/* legal disclaimer in /opt/starfish/data/starfish/sql-copyright-and-license.md */

CREATE EXTENSION IF NOT EXISTS plpython3u;


CREATE OR REPLACE FUNCTION tags_report_plpython(flavour VARCHAR)
  RETURNS TABLE (
    volume_id BIGINT,
    zone_tag_name_id BIGINT,
    tag_name_id BIGINT,
    atime_age TEXT,
    mtime_age TEXT,
    logical_size BIGINT,
    physical_size BIGINT,
    count BIGINT
  )
AS $$
    from collections import Counter, namedtuple
    from operator import itemgetter
    import time

    def log(msg):
        """Emit msg through plpy.log, prefixed with the report marker and the `flavour` argument."""
        prefixed_message = f"[report][{flavour}] {msg}"
        plpy.log(prefixed_message)

    # One entry of the result of PathsTrie.get_list_of_excluding_subtrees: a subtree root
    # (node_path / node_path_id), the nested subtree roots to leave out (paths / ids, both sets)
    # and the labels inherited by the root (node_labels).
    ExcludedSubtree = namedtuple('ExcludedSubtree', ['node_path', 'node_path_id', 'paths', 'ids', 'node_labels'])

    # Age bucket expressed in days; max_days is None for the open-ended last bucket.
    TimeBucket = namedtuple('TimeBucket', ['min_days', 'max_days', 'description'])

    # Preferred way of changing buckets in custom tags and zone namespace tags reports
    # is to run `configure_redash buckets` command.
    rows = plpy.execute(f"SELECT lower_bound_month, label FROM sf_reports.age_buckets ORDER BY lower_bound_month")
    # The first bucket always starts at day 0 and takes the label of the first row;
    # the first row's own lower_bound_month is never read.
    # NOTE(review): rows[0] raises IndexError when sf_reports.age_buckets is empty -
    # presumably the table is always populated; confirm.
    last_bound_days, last_label = 0, rows[0]['label']
    time_buckets = []
    def months_to_days(months):
        """Convert months to days: every full 12 months counts as 365 days, each leftover month as 30."""
        full_years, leftover_months = divmod(months, 12)
        return full_years * 365 + leftover_months * 30

    # Every following row closes the bucket opened by the previous row: the previous label
    # covers [last_bound_days, this row's lower bound in days - 1].
    for row in rows[1:]:
        max_days = months_to_days(row['lower_bound_month'])
        time_buckets.append(TimeBucket(min_days=last_bound_days, max_days=max_days - 1, description=last_label))
        last_label = row['label']
        last_bound_days = max_days
    # The label of the last row gets an open-ended bucket (max_days=None).
    time_buckets.append(TimeBucket(min_days=last_bound_days, max_days=None, description=last_label))
    # Reverse the list so that the bucket with the largest min_days comes first.
    time_buckets = time_buckets[::-1]


    class PathsTrieNode(dict):
        """
        Single trie node. The dict part maps a child name (one path component) to the child node;
        on top of that the node carries an `added` flag, an optional set of labels and optional data.
        """

        def __init__(self, labels=None, data=None):
            super().__init__()

            # `added` marks that some path ends exactly here. For paths a/b/c and a/d/e we get
            #             ''
            #             |
            #             a
            #          /     \
            #        b         d
            #       /           \
            #      c             e
            # (added=True)   (added=True)
            self.added = False
            # Both stay None until something is stored, to save space on plain nodes.
            self._labels = None
            self.data = None
            self.update_labels(labels)
            self.update_data(data)

        def update_labels(self, labels):
            """Merge `labels` into the node's label set; empty / None input changes nothing."""
            if labels:
                if self._labels is None:
                    self._labels = set()
                self._labels.update(labels)

        def update_data(self, data):
            """Store `data`: an int replaces the current value, anything else is merged dict-style."""
            if data is None:
                return
            if isinstance(data, int):
                self.data = data
                return
            if self.data is None:
                self.data = {}
            self.data.update(data)

        def get_labels(self):
            """Return the node's labels; a fresh empty set when none were ever stored."""
            return set() if self._labels is None else self._labels

        def get_data(self):
            """Return whatever was stored with update_data (None when nothing was)."""
            return self.data


    class PathsTrie:
        """
        This class keeps trie like structure where each node is single dir/file name from entry path.
        So for path dir/subdir/file we split it into nodes dir <- subdir <- file.
        With that structure we can easily find unique root paths for set of paths.
        To each node of trie we can assign labels and then ask only for nodes which are under given labels.
        Please see doctests for example.
        WARNING: This class is not memory efficient and should not be used for large data sets
        Inefficiency is hidden in PathsTrieNode object which is subclass of python dict enriched with additional
        set of labels. For example simple trie PathsTrie(paths=['ala/ma/kota']) allocates 4 nodes and occupies 1.7 KB
        of memory. Trie with 1M of unique paths (for paths: "1", "2", "3", ...) takes ~650MB of memory.
        Using labels will increase occupied memory by ~1KB for each node with labels (+ size of labels).
        >>> sorted(PathsTrie(['a/b/c', 'd/e/f', 'a/b']).to_list())
        ['a/b', 'a/b/c', 'd/e/f']
        >>> PathsTrie(['a/b/c', 'a/b/c/d/e']).unique_roots() == ['a/b/c']
        True
        >>> PathsTrie(['a/b/c1', 'a/b/c2', 'a/b2']).unique_roots() == ['a/b/c1', 'a/b/c2', 'a/b2']
        True
        >>> PathsTrie(['', 'a/b/c', 'a/b/c/d/e']).unique_roots()
        ['']
        >>> PathsTrie([]).unique_roots()
        []
        >>> PathsTrie(['a'], data=7).unique_roots_with_data()
        [('a', 7)]
        >>> pt = PathsTrie()
        >>> pt.add_path('a/b', labels={'1'})
        >>> pt.add_path('a/b/c', labels={'2'})
        >>> pt.add_path('a/b/c/d', labels={'3'})
        >>> pt.add_path('z', labels={'3'})
        >>> sorted(pt.to_list(required_labels={'2'})) == ['a/b/c', 'a/b/c/d']
        True
        >>> sorted(pt.unique_roots(required_labels={'2'})) == ['a/b/c']
        True
        >>> sorted(pt.to_list(required_labels={'2', '3'})) == ['a/b/c/d']
        True
        >>> sorted(pt.to_list(required_labels={'3'})) == ['a/b/c/d', 'z']
        True
        >>> sorted(pt.unique_roots(required_labels={'3'})) == ['a/b/c/d', 'z']
        True
        >>> pt = PathsTrie()
        >>> pt.add_paths(['a', 'a/b', 'b', 'a/b/c'], labels=None, skip_if_ancestor_present=True)
        >>> sorted(pt.to_list())
        ['a', 'b']
        >>> sorted(pt.unique_roots())
        ['a', 'b']
        >>> pt = PathsTrie()
        >>> pt.add_paths(["a/b/c", "a/b/c/d", "a/b", "a", "a/b/c", ""], skip_if_child_present=True)
        >>> sorted(pt.to_list()) == ['a/b/c', 'a/b/c/d']
        True
        >>> pt = PathsTrie()
        >>> pt.add_path('a/b', labels={'x', 'y'})
        >>> pt.add_path('a/b/c', labels={'x'}, skip_if_ancestor_present=True)
        >>> sorted(pt.to_list())
        ['a/b']
        >>> pt.add_path('a/b/c', labels={'z'}, skip_if_ancestor_present=True)
        >>> sorted(pt.to_list())
        ['a/b', 'a/b/c']
        >>> pt = PathsTrie()
        >>> pt.add_path('a/b', labels={'x'})
        >>> pt.add_path('a/b', labels={'y'})
        >>> pt.add_path('a/b', labels={'z'})
        >>> sorted(pt.to_list(required_labels={'x', 'y', 'z'}))
        ['a/b']
        >>> sorted(pt.get_inherited_labels('a/b'))
        ['x', 'y', 'z']
        >>> pt = PathsTrie()
        >>> pt.add_path('a', labels={'x'})
        >>> pt.add_path('a/b', labels={'y'})
        >>> pt.add_path('a/b/c/d', labels={'z'})
        >>> sorted(pt.get_inherited_labels('a/b/c'))
        ['x', 'y']
        >>> sorted(pt.get_inherited_labels('a/b/c/d'))
        ['x', 'y', 'z']
        >>> sorted(pt.get_inherited_labels('a/b/c/d/e/f'))
        ['x', 'y', 'z']
        >>> sorted(pt.get_inherited_labels('k/l/m'))
        []
        >>> sorted(pt.get_all_labels())
        ['x', 'y', 'z']
        """

        def __init__(self, paths=None, labels=None, data=None):
            """Create a trie, optionally pre-populated with `paths` sharing the same `labels` and `data`."""
            self.root = PathsTrieNode()
            if paths:
                # `data` used to be accepted here but never forwarded; pass it on so it is stored.
                self.add_paths(paths, labels=labels, data=data)

        @classmethod
        def from_node(cls, node: PathsTrieNode):
            """Wrap an existing node as the root of a new trie (the node is shared, not copied)."""
            trie = cls()  # cls() rather than PathsTrie() so that subclasses get their own type back
            trie.root = node
            return trie

        def add_paths(self, paths, labels=None, data=None, skip_if_ancestor_present=False, skip_if_child_present=False):
            """Call add_path for every element of `paths` with the same remaining arguments."""
            if not paths:
                return
            for path in paths:
                self.add_path(
                    path,
                    labels=labels,
                    data=data,
                    skip_if_ancestor_present=skip_if_ancestor_present,
                    skip_if_child_present=skip_if_child_present,
                )

        def add_path(self, path, labels=None, data=None, skip_if_ancestor_present=False, skip_if_child_present=False):
            """
            Insert `path` ('/'-separated, '' means the root) and mark its last node as added.

            :param labels: iterable of labels merged into the last node (None / empty for no labels)
            :param data: int or dict stored on the last node (see PathsTrieNode.update_data)
            :param skip_if_ancestor_present: do nothing when an already added ancestor carries
                all of `labels`
            :param skip_if_child_present: do nothing when a node for `path` already exists, i.e. the
                path itself or any longer path going through it was inserted before
            """
            if skip_if_child_present and self.get_node(path) is not None:
                return
            # Normalize to a private set so that any iterable is accepted by issubset() below.
            labels = set(labels) if labels else set()
            current_node = self.root
            path_parts = path.split('/') if path else []
            for path_part in path_parts:
                if skip_if_ancestor_present and current_node.added and labels.issubset(current_node.get_labels()):
                    return
                # Explicit lookup instead of setdefault(): no throw-away node per existing component.
                child = current_node.get(path_part)
                if child is None:
                    child = current_node[path_part] = PathsTrieNode()
                current_node = child
            current_node.added = True
            current_node.update_labels(labels)
            current_node.update_data(data)

        def get_node(self, path):
            """Return the node for `path` or None when the trie has no such node."""
            current_node = self.root
            path_parts = path.split('/') if path else []
            for path_part in path_parts:
                current_node = current_node.get(path_part, None)
                if current_node is None:
                    return None
            return current_node

        def get_inherited_labels(self, path):
            """
            Return the union of labels of all nodes from the root down to `path`.
            When `path` leaves the trie, labels collected up to the last existing node are returned.
            """
            current_node = self.root
            inherited_labels = set(current_node.get_labels())
            path_parts = path.split('/') if path else []
            for path_part in path_parts:
                current_node = current_node.get(path_part, None)
                if current_node is None:
                    return inherited_labels
                inherited_labels.update(current_node.get_labels())
            return inherited_labels

        @staticmethod
        def join_path(path, key):
            # do not want to use os.path.join to always use '/' as separator
            if path:
                return path + '/' + key
            return key

        def _rec_iter(self, path, current_node, required_labels, inherited_labels, unique_roots):
            """
            Depth-first walk collecting paths of added nodes whose inherited labels satisfy
            `required_labels` (a set that must be a subset, or a predicate called with the labels).
            With `unique_roots` the walk does not descend below a matching node.
            `inherited_labels` is one shared set mutated on the way down and restored on the way up.
            """
            added_labels = set()
            paths = []
            if current_node.added:
                # Remember only what this node contributed so exactly that can be undone later.
                added_labels = current_node.get_labels() - inherited_labels
                inherited_labels.update(added_labels)

                if callable(required_labels):
                    labels_ok = required_labels(inherited_labels)
                else:
                    labels_ok = required_labels <= inherited_labels
                if labels_ok:
                    if unique_roots:
                        inherited_labels -= added_labels
                        return [path]
                    else:
                        paths.append(path)

            for key in current_node:
                paths.extend(
                    self._rec_iter(
                        self.join_path(path, key), current_node[key], required_labels, inherited_labels, unique_roots
                    )
                )
            inherited_labels -= added_labels
            return paths

        def to_list(self, required_labels=None):
            """
            Returns list of nodes in tree with required labels. If required labels is None then it returns all paths in trie
            """
            required_labels = set() if not required_labels else required_labels
            return self._rec_iter(
                path='', current_node=self.root, required_labels=required_labels, inherited_labels=set(), unique_roots=False
            )

        def unique_roots(self, required_labels=None):
            """
            Returns unique roots in tree with required labels.

            >>> pt = PathsTrie()
            >>> pt.add_path('a', labels={'label1'})
            >>> pt.add_path('a/b/c', labels={'label2'})
            >>> pt.add_path('x', labels={'label2'})
            >>> pt.add_path('x/y/z', labels={'label1'})
            >>> sorted(pt.unique_roots(lambda labels: 'label1' in labels and 'label2' in labels))
            ['a/b/c', 'x/y/z']
            """
            required_labels = set() if not required_labels else required_labels
            return self._rec_iter(
                path='', current_node=self.root, required_labels=required_labels, inherited_labels=set(), unique_roots=True
            )

        def unique_roots_with_data(self, required_labels=None):
            """
            Same as unique_roots but returns (path, data) tuples, data being what was stored on the node.

            >>> # data field can contain dict
            >>> t = PathsTrie()
            >>> path_to_number = {'a/b': 2, 'a/b/c': 3, 'a/d': 4, 'a/d/e': 5, 'a/d/e/f': 6}
            >>> for path, number in path_to_number.items():
            ...   t.add_path(path, data={'id': number})
            >>> sorted(t.unique_roots_with_data())
            [('a/b', {'id': 2}), ('a/d', {'id': 4})]
            >>> # data field can contain int
            >>> t = PathsTrie()
            >>> for path, number in path_to_number.items():
            ...   t.add_path(path, data=number)
            >>> sorted(t.unique_roots_with_data())
            [('a/b', 2), ('a/d', 4)]
            """
            root_paths = self.unique_roots(required_labels=required_labels)
            result = []
            for root_path in root_paths:
                node = self.get_node(root_path)
                result.append((root_path, node.data))
            return result

        def _get_all_labels_rec(self, node: PathsTrieNode, labels: set):
            """Add labels of `node` and of all its descendants to `labels` (in place) and return it."""
            labels |= node.get_labels()
            for child in node.values():
                # The recursion mutates `labels` itself, no need to merge its return value again.
                self._get_all_labels_rec(child, labels)
            return labels

        def get_all_labels(self):
            """Return the set of all labels used anywhere in the trie."""
            return self._get_all_labels_rec(self.root, labels=set())

        def _get_nodes_without_labels_rec(self, path, current_node, labels):
            """
            Walk down from `current_node` and return (paths, data values) of the topmost added nodes
            that carry at least one label not contained in `labels`; the walk stops at such nodes.
            """
            paths_with_labels = []
            ids_with_labels = []
            if current_node.added:
                node_labels = current_node.get_labels()
                different_labels = node_labels - labels
                if different_labels:
                    return [path], [current_node.data]

            for key in current_node:
                paths, ids = self._get_nodes_without_labels_rec(self.join_path(path, key), current_node[key], labels)
                paths_with_labels.extend(paths)
                ids_with_labels.extend(ids)

            return paths_with_labels, ids_with_labels

        def get_roots_of_trees_with_different_labels(self, path, labels):
            """
            Return (paths, data values) of the topmost nodes below `path` whose labels are not all
            in `labels`. A `path` missing from the trie has no such nodes, so two empty lists are
            returned - mirroring get_inherited_labels, which also tolerates unknown paths.
            """
            node = self.get_node(path)
            if node is None:
                return [], []
            return self._get_nodes_without_labels_rec(path=path, current_node=node, labels=labels)

        def get_list_of_excluding_subtrees(self, node_path, node_path_id):
            # this very specialized method used to generate SQL queries for
            # for tags report which will efficiently get those directories from database which
            # have the same inherited tags. For given node_path it will return list of tuples.
            # Each tuple contains path, excluded_subpaths and inherited labels (tags).
            # For example if we have root directory with tag assigned to it and no other tags
            # in this subtree this method will return [(node_path, [], inherited_labels_for_node_path)].
            # This will allow us to create query:
            # SELECT * FROM sf.dir_current WHERE volume_id=X
            #     AND (path=node_path or path LIKE subtree_pattern(node_path))
            # to get all directories for this subtree and we know that all those dirs have inherited_tags same as node_path.
            # Now lets assume that we add new tag to path node_path/subdir and add it with proper label to PathsTrie.
            # In such case result of this method will change to:
            # [
            #   (node_path, [node_path/subdir], inherited_labels_for_node_path),
            #   (node_path/subdir, [], inherited_labels_for_node_path + new_tag),
            # ]
            # Then we can use this result to create queries with "excluded subtrees:
            # SELECT * FROM sf.dir_current WHERE volume_id=X
            #     AND (path=node_path or path LIKE subtree_pattern(node_path))
            #     AND path!=node_path/subdir AND path NOT LIKE subtree_pattern(node_path/subdir))
            # Query above is for the first element of result. Here is query for second:
            # SELECT * FROM sf.dir_current WHERE volume_id=X
            #     AND (path=node_path/subdir or path LIKE subtree_pattern(node_path/subdir))

            node_inherited_labels = self.get_inherited_labels(node_path)
            paths_to_exclude, ids_to_exclude = self.get_roots_of_trees_with_different_labels(node_path, node_inherited_labels)
            if paths_to_exclude:
                result = [ExcludedSubtree(node_path, node_path_id, set(paths_to_exclude), set(ids_to_exclude), node_inherited_labels)]
                # Every excluded subtree becomes a query unit of its own (possibly with own exclusions).
                for path, path_id in zip(paths_to_exclude, ids_to_exclude):
                    result.extend(self.get_list_of_excluding_subtrees(path, path_id))
                return result
            else:
                return [ExcludedSubtree(node_path, node_path_id, set(), set(), node_inherited_labels)]


    def fast_log_2(x):
        """
        Returns ~ int(log(x, 2)) for an integer x, clamped to 0 for x <= 0.
        Uses int.bit_length, which avoids building the bin() string.
        Negative input used to count the '-' sign of bin(x) and ended up in a too large bucket.
        >>> fast_log_2(0)
        0
        >>> fast_log_2(1)
        0
        >>> fast_log_2(2)
        1
        >>> fast_log_2(1023)
        9
        >>> fast_log_2(1024)
        10
        >>> fast_log_2(-5)
        0
        """
        if x <= 0:
            return 0
        return x.bit_length() - 1

    class Histogram:
        """
        Collects durations (passed in seconds, kept in milliseconds) into power-of-two buckets
        and renders them as a multi-line text summary.
        """

        def __init__(self, name):
            self._name = name
            self._counter = Counter()  # bucket exponent -> number of samples
            self.total_time = 0
            self._min_value = 2 ** 32
            self._max_value = 0

        def add_sample(self, seconds):
            """Register one duration given in seconds."""
            millis = seconds * 1000
            self._counter[fast_log_2(int(millis))] += 1
            self.total_time += millis
            self._max_value = max(self._max_value, millis)
            self._min_value = min(self._min_value, millis)

        def __str__(self):
            samples = sum(self._counter.values())
            if samples > 0:
                header = (
                    f'{self._name}: total: {self.total_time:.2f}ms;'
                    f' count: {samples};'
                    f' avg: {self.total_time / samples:.2f}ms;'
                    f' min: {self._min_value:.2f}ms;'
                    f' max: {self._max_value:.2f}ms'
                )
            else:
                header = f' {self._name}:'
            lines = [header]
            # One line per non-empty bucket, smallest exponent first.
            for exponent in sorted(self._counter):
                lower = str(pow(2, exponent) if exponent > 0 else 0).rjust(7)
                upper = str(pow(2, exponent + 1)).rjust(7)
                hits = str(self._counter[exponent]).rjust(8)
                lines.append(f' {lower}ms to {upper}ms - {hits}')
            return '\n'.join(lines)


    class GroupingMeasureTime:
        """
        Re-entrant context manager measuring how long `with` blocks take, grouped by stat name.

        Usage: `with mt.with_stat_name("name"): ...` - with_stat_name pushes the name, __enter__
        pushes the start time and __exit__ pops both and adds the duration to the Histogram kept
        for that name. Blocks left through an exception are not recorded.
        """

        def __init__(self):
            self._stat_name_to_histogram = {}
            self._stat_name_stack = []
            self._start_time_stack = []

        def with_stat_name(self, stat_name):
            """Announce the name for the `with` block that is about to be entered; returns self."""
            self._stat_name_stack.append(stat_name)
            return self

        def add_sample(self, stat_name, duration):
            """Record `duration` (seconds) under `stat_name` without using a `with` block."""
            self._get_histogram(stat_name).add_sample(duration)

        def _get_histogram(self, stat_name):
            """Return the Histogram for `stat_name`, creating it on first use."""
            if stat_name not in self._stat_name_to_histogram:
                self._stat_name_to_histogram[stat_name] = Histogram(stat_name)
            return self._stat_name_to_histogram[stat_name]

        def __enter__(self):
            # Validate before pushing, so a misuse does not leave a stray start time behind.
            if len(self._start_time_stack) + 1 != len(self._stat_name_stack):
                raise AssertionError("You used GroupingMeasureTime without .with_stat_name!")
            # Monotonic clock: a wall-clock adjustment must not produce bogus or negative durations.
            self._start_time_stack.append(time.monotonic())
            return self

        def __exit__(self, exc_type, _exc_val, _exc_tb):
            # Always unwind both stacks. Leaving them untouched on an exception made an enclosing
            # block pop the inner entries later and book its time under the inner stat name.
            start_time = self._start_time_stack.pop()
            stat_name = self._stat_name_stack.pop()
            if exc_type is not None:  # exception was thrown - do not record, let it propagate
                return
            self.add_sample(stat_name, time.monotonic() - start_time)

        def print_all_stats(self, print_fun):
            """Pass the text form of every collected histogram to `print_fun`, one call per stat."""
            for histogram in self._stat_name_to_histogram.values():
                print_fun(str(histogram))

        def reset(self):
            """Forget all collected statistics and any pending names / start times."""
            self._stat_name_to_histogram = {}
            self._stat_name_stack = []
            self._start_time_stack = []

    def prepare_table_entries_with_explicit_tags(sql):
        """
        Materialize the result of `sql` as temp table tags__entries_with_explicit_tags
        (dropped on commit). `sql` is embedded verbatim into the statement.
        """
        create_statement = f"CREATE TEMP TABLE tags__entries_with_explicit_tags ON COMMIT DROP AS ({sql})"
        with mt.with_stat_name("tags__entries_with_explicit_tags"):
            plpy.execute(create_statement)

    def prepare_table_dirs_with_explicit_tags():
        """
        Build temp table tags__dirs_with_explicit_tags (dropped on commit) by joining
        sf.dir_current with tags__entries_with_explicit_tags on id and volume_id,
        then ANALYZE it.
        """
        with mt.with_stat_name("tags__dirs_with_explicit_tags"):
            create_statement = """
                CREATE TEMP TABLE tags__dirs_with_explicit_tags ON COMMIT DROP AS
                    SELECT dir.volume_id, dir.id, ancestor_ids, dir.path, tag.zone_tag_name_ids, tag.tag_name_ids
                    FROM sf.dir_current AS dir
                        JOIN tags__entries_with_explicit_tags AS tag
                            ON dir.id = tag.id AND dir.volume_id = tag.volume_id
                    ORDER BY dir.volume_id, dir.path
            """
            plpy.execute(create_statement)
            plpy.execute("ANALYZE tags__dirs_with_explicit_tags")

    class ZoneTagNameId(int):
        """ ZoneTagNameIds and TagNameIds are kept in a single set. They have to differ from each other.

        Equality and hash therefore include the class: a ZoneTagNameId equals only another
        ZoneTagNameId with the same numeric value, never a plain int or a different int subclass.
        """
        def __hash__(self):
            return hash((int(self), self.__class__))

        def __eq__(self, other):
            return isinstance(other, ZoneTagNameId) and int(self) == int(other)

        def __ne__(self, other):
            # int defines its own __ne__, which would otherwise be inherited and compare plain
            # numeric values - making `a == b` and `a != b` both False for the same pair.
            return not self.__eq__(other)

        def __str__(self):
            return str(int(self))

        def __repr__(self):
            return f"ZoneTagNameId({int(self)})"

    class TagNameId(int):
        """
        Integer id that compares and hashes together with its class: a TagNameId equals only
        another TagNameId with the same numeric value, never a plain int or a different int subclass.
        """
        def __hash__(self):
            return hash((int(self), self.__class__))

        def __eq__(self, other):
            return isinstance(other, TagNameId) and int(self) == int(other)

        def __ne__(self, other):
            # int defines its own __ne__, which would otherwise be inherited and compare plain
            # numeric values - making `a == b` and `a != b` both False for the same pair.
            return not self.__eq__(other)

        def __str__(self):
            return str(int(self))

        def __repr__(self):
            return f"TagNameId({int(self)})"

    def create_tries(relname):
        """
        Read volume_id, id, path, zone_tag_name_ids and tag_name_ids from `relname` and build one
        PathsTrie per volume id. Rows with a NULL path only make sure the volume has a trie;
        other rows are added with their zone / tag name ids as labels and the row id as data,
        unless an already added ancestor carries all of those labels.

        :return: dict volume_id -> PathsTrie
        """
        with mt.with_stat_name(f"create_tries"):
            tries_by_volume = {}
            for record in plpy.execute(f"SELECT volume_id, id, path, zone_tag_name_ids, tag_name_ids FROM {relname}"):
                volume = record["volume_id"]
                volume_trie = tries_by_volume.get(volume)
                if volume_trie is None:
                    volume_trie = tries_by_volume[volume] = PathsTrie()
                if record["path"] is None:
                    continue
                # Wrap ids in distinct types so zone tag ids and tag ids can share one label set.
                node_labels = {ZoneTagNameId(zone_id) for zone_id in record["zone_tag_name_ids"] or ()}
                node_labels |= {TagNameId(tag_id) for tag_id in record["tag_name_ids"] or ()}
                volume_trie.add_path(record["path"], labels=node_labels, data=record["id"], skip_if_ancestor_present=True)
        return tries_by_volume

    def create_table_dirs_with_inherited_tags():
        """
        Create the empty temp table tags__dirs_with_inherited_tags (dropped on commit):
        one row per directory with its zone / tag name id arrays, min / max atime and mtime
        ages in days, files count, logical size and blocks.
        """
        plpy.execute(f"""
            CREATE TEMP TABLE tags__dirs_with_inherited_tags (
                volume_id BIGINT,
                id BIGINT,
                zone_tag_name_ids BIGINT[],
                tag_name_ids BIGINT[],
                min_atime_days float8,
                max_atime_days float8,
                min_mtime_days float8,
                max_mtime_days float8,
                files_count bigint,
                logical_size bigint,
                blocks bigint
            ) ON COMMIT DROP
        """)

    def create_table_dirs_in_atime_mtime_buckets():
        """
        Create two empty temp tables (both dropped on commit) with the same columns:
        tags__dirs_in_atime_mtime_buckets and tags__dirs_with_different_atime_and_mtime_age.
        """
        plpy.execute(f"""
            CREATE TEMP TABLE tags__dirs_in_atime_mtime_buckets (
                id BIGINT,
                volume_id BIGINT,
                atime_bucket TEXT,
                mtime_bucket TEXT,
                zone_tag_name_ids BIGINT[],
                tag_name_ids BIGINT[],
                files_count BIGINT,
                logical_size BIGINT,
                blocks BIGINT
            ) ON COMMIT DROP
        """)
        # The second table only copies the column layout: its source was created just above and is empty.
        plpy.execute(f"""
            CREATE TEMP TABLE tags__dirs_with_different_atime_and_mtime_age ON COMMIT DROP AS
                SELECT * FROM tags__dirs_in_atime_mtime_buckets -- this table is empty
        """)

    def get_volume_ids():
        """
        Return the list of distinct volume ids found in tags__entries_with_explicit_namespace_tags
        or tags__entries_with_explicit_zone_tags.
        """
        with mt.with_stat_name("get_volumes_ids"):
            volumes_query = """
                SELECT DISTINCT volume_id FROM tags__entries_with_explicit_namespace_tags
                UNION
                SELECT DISTINCT volume_id FROM tags__entries_with_explicit_zone_tags
            """
            return [record["volume_id"] for record in plpy.execute(volumes_query)]

    def get_volumes_sizes(trie_per_vol_id):
        """
        Sum rec_aggrs->>'files' of the unique root directories of every trie, grouped by volume.

        :param trie_per_vol_id: dict volume_id -> PathsTrie whose node data are directory ids
        :return: (dict volume_id -> files count, total files count); the total is forced to 1
            when it would be <= 0
        """
        with mt.with_stat_name("get_volumes_size"):
            root_dir_ids = [
                dir_id
                for trie in trie_per_vol_id.values()
                for _path, dir_id in trie.unique_roots_with_data()
            ]
            if root_dir_ids:
                dir_ids_str = ', '.join(map(str, root_dir_ids))
                rows = plpy.execute(f"""
                    select
                        volume_id,
                        COALESCE(SUM((rec_aggrs->>'files')::BIGINT), 0) as files
                    FROM sf.dir_current
                    WHERE id IN ({dir_ids_str})
                    GROUP BY volume_id
                """)
            else:
                rows = []
            volume_id_to_files_count = {}
            files_total = 0
            for row in rows:
                volume, volume_files = row["volume_id"], row["files"]
                volume_id_to_files_count[volume] = volume_files
                files_total += volume_files
                log(f"vol: {volume} files: {volume_files:,}")
            log(f"Total files: {files_total:,}")
            if files_total <= 0:
                files_total = 1  # this is guard - we cannot return as there might be some dirs with no aggregates yet
        return volume_id_to_files_count, files_total

    def log_stats_if_necessary(last_print_stats_time):
        """
        Log all collected timing stats when at least an hour (3600 s) passed since
        `last_print_stats_time` (a time.time() timestamp).

        :return: the timestamp to pass to the next call: the current time when stats were
            logged, `last_print_stats_time` otherwise. The refreshed timestamp used to be
            assigned to the local parameter only and was lost; a caller has to store the
            returned value to get hourly logging.
        """
        if time.time() - last_print_stats_time >= 3600:
            mt.print_all_stats(log)  # print all stats at beginning of each vol
            return time.time()
        return last_print_stats_time

    def truncate_dirs_temp_tables():
        """Empty the three per-directory temp tables, in a fixed order, with one TRUNCATE each."""
        temp_tables = (
            "tags__dirs_with_inherited_tags",
            "tags__dirs_in_atime_mtime_buckets",
            "tags__dirs_with_different_atime_and_mtime_age",
        )
        with mt.with_stat_name("truncate_dirs_inherited"):
            for table_name in temp_tables:
                plpy.execute(f"TRUNCATE {table_name};")

    def get_quoted(str_):
        """Double every single quote in `str_` (surrounding quotes are not added)."""
        return "''".join(str_.split("'"))

    def calculate_inherited_tags_for_volume(vol_id, trie, required_labels_func=None):
        explicitly_tagged_roots_with_ids = trie.unique_roots_with_data(required_labels=required_labels_func)
        insert_recursive_count = 0
        insert_recursive_dirs = 0
        insert_subtree_count = 0
        insert_subtree_dirs = 0
        now = plpy.execute('SELECT EXTRACT(EPOCH FROM(current_timestamp)) as now')[0]["now"]
        now = int(now)

        with mt.with_stat_name("insert_dirs_into_tags__dirs_with_inherited_tags"):
            for path, id in explicitly_tagged_roots_with_ids:
                excluded_subtrees = trie.get_list_of_excluding_subtrees(path, id)
                max_excluded_paths_len = max([len(excluded_subtree.paths) for excluded_subtree in excluded_subtrees])
                if max_excluded_paths_len > GD.get('tags_calculation_with_recursion_threshold', 100):  # number 100 is based on STAR-5343
                    insert_recursive_count += 1
                    with mt.with_stat_name("insert_dirs_with_recursive"):
                        inherited_tag_name_ids = trie.get_inherited_labels(path)
                        zone_tag_name_ids = [tag_name_id for tag_name_id in inherited_tag_name_ids if isinstance(tag_name_id, ZoneTagNameId)]
                        if not zone_tag_name_ids:
                            # ZoneTagNameIds exist only in zone_namespace_tags report
                            zone_tag_name_ids_str = 'NULL'
                        else:
                            zone_tag_name_ids_str = ','.join(str(tag_name_id) for tag_name_id in zone_tag_name_ids)
                        tag_name_ids_str = ','.join(str(tag_name_id) for tag_name_id in inherited_tag_name_ids if isinstance(tag_name_id, TagNameId))
                        plpy.execute(f"""
                            CREATE TEMP TABLE tags__dirs_with_explicit_tags_filtered AS
                                SELECT * from tags__dirs_with_explicit_tags
                                WHERE volume_id = {vol_id} AND ancestor_ids && ARRAY[{id}]::BIGINT[]
                        """)
                        plpy.execute(f"""ANALYZE tags__dirs_with_explicit_tags_filtered""")
                        result = plpy.execute(f"""
                            WITH RECURSIVE dirs_with_inherited_tags_calculated_recursively AS (
                                SELECT
                                    dir.volume_id,
                                    dir.id,
                                    dir.parent_id,
                                    ARRAY[{zone_tag_name_ids_str}]::BIGINT[] zone_tag_name_ids,
                                    ARRAY[{tag_name_ids_str}]::BIGINT[] tag_name_ids,
                                    ({now} - (local_aggrs->'min'->>'atime')::FLOAT8) / 86400.0 AS min_atime_days,
                                    ({now} - (local_aggrs->'max'->>'atime')::FLOAT8) / 86400.0 AS max_atime_days,
                                    ({now} - (local_aggrs->'min'->>'mtime')::FLOAT8) / 86400.0 AS min_mtime_days,
                                    ({now} - (local_aggrs->'max'->>'mtime')::FLOAT8) / 86400.0 AS max_mtime_days,
                                    (local_aggrs->'total'->>'files')::BIGINT AS files_count,
                                    (local_aggrs->'total'->>'size')::BIGINT AS logical_size,
                                    COALESCE((local_aggrs->'total'->>'blocks_div_nlinks')::BIGINT, (local_aggrs->'total'->>'blocks')::BIGINT) AS blocks
                                    FROM tags__dirs_with_explicit_tags_filtered AS explicitly_tagged_dirs
                                    INNER JOIN sf.dir_current AS dir ON dir.id = explicitly_tagged_dirs.id
                                    WHERE explicitly_tagged_dirs.path = '{get_quoted(path)}' AND dir.volume_id={vol_id}

                                UNION ALL
                                -- select directories from tagged_dirs and propagate tags from the ancestors
                                SELECT
                                    child.volume_id,
                                    child.id,
                                    child.parent_id,
                                    ARRAY(SELECT DISTINCT UNNEST(parent.zone_tag_name_ids || CASE WHEN explicitly_tagged_dirs.zone_tag_name_ids IS NOT NULL THEN explicitly_tagged_dirs.zone_tag_name_ids ELSE ARRAY[]::BIGINT[] END)) as zone_tag_name_ids,
                                    ARRAY(SELECT DISTINCT UNNEST(parent.tag_name_ids || CASE WHEN explicitly_tagged_dirs.tag_name_ids IS NOT NULL THEN explicitly_tagged_dirs.tag_name_ids ELSE ARRAY[]::BIGINT[] END)) as tag_name_ids,
                                    ({now} - (local_aggrs->'min'->>'atime')::FLOAT8) / 86400.0 AS min_atime_days,
                                    ({now} - (local_aggrs->'max'->>'atime')::FLOAT8) / 86400.0 AS max_atime_days,
                                    ({now} - (local_aggrs->'min'->>'mtime')::FLOAT8) / 86400.0 AS min_mtime_days,
                                    ({now} - (local_aggrs->'max'->>'mtime')::FLOAT8) / 86400.0 AS max_mtime_days,
                                    (local_aggrs->'total'->>'files')::BIGINT AS files_count,
                                    (local_aggrs->'total'->>'size')::BIGINT AS logical_size,
                                    COALESCE((local_aggrs->'total'->>'blocks_div_nlinks')::BIGINT, (local_aggrs->'total'->>'blocks')::BIGINT) AS blocks
                                    FROM sf.dir_current AS child
                                    INNER JOIN dirs_with_inherited_tags_calculated_recursively AS parent ON parent.id = child.parent_id
                                    LEFT JOIN tags__dirs_with_explicit_tags_filtered AS explicitly_tagged_dirs ON child.id = explicitly_tagged_dirs.id
                                    WHERE child.volume_id = {vol_id}

                            ) INSERT INTO tags__dirs_with_inherited_tags
                                SELECT volume_id, id, zone_tag_name_ids, tag_name_ids, min_atime_days, max_atime_days, min_mtime_days, max_mtime_days, files_count, logical_size, blocks
                                FROM dirs_with_inherited_tags_calculated_recursively
                        """)
                        insert_recursive_dirs += result.nrows() or 0
                        plpy.execute(f"""
                            DROP TABLE tags__dirs_with_explicit_tags_filtered
                        """)
                else:
                    insert_subtree_count += 1
                    for subtree_path, subtree_id, excluded_paths, excluded_ids, inherited_tag_name_ids in excluded_subtrees:
                        zone_tag_name_ids = [tag_name_id for tag_name_id in inherited_tag_name_ids if isinstance(tag_name_id, ZoneTagNameId)]
                        if not zone_tag_name_ids:
                            # ZoneTagNameIds exist only in zone_namespace_tags report
                            zone_tag_name_ids_str = 'NULL'
                        else:
                            zone_tag_name_ids_str = ','.join(str(tag_name_id) for tag_name_id in zone_tag_name_ids)
                        tag_name_ids_str = ','.join(str(tag_name_id) for tag_name_id in inherited_tag_name_ids if isinstance(tag_name_id, TagNameId))

                        excluded_paths_str = [
                            f" AND NOT ancestor_ids && ARRAY[{','.join(str(excluded_id) for excluded_id in excluded_ids)}]::BIGINT[]"
                        ]
                        subtree_str = f"AND ancestor_ids && ARRAY[{subtree_id}]::BIGINT[]"

                        insert_stmt = f"""
                            INSERT INTO tags__dirs_with_inherited_tags
                                SELECT
                                    volume_id,
                                    id,
                                    ARRAY[{zone_tag_name_ids_str}]::BIGINT[] AS zone_tag_name_ids,
                                    ARRAY[{tag_name_ids_str}]::BIGINT[] AS tag_name_ids,
                                    ({now} - (local_aggrs->'min'->>'atime')::FLOAT8) / 86400.0 AS min_atime_days,
                                    ({now} - (local_aggrs->'max'->>'atime')::FLOAT8) / 86400.0 AS max_atime_days,
                                    ({now} - (local_aggrs->'min'->>'mtime')::FLOAT8) / 86400.0 AS min_mtime_days,
                                    ({now} - (local_aggrs->'max'->>'mtime')::FLOAT8) / 86400.0 AS max_mtime_days,
                                    (local_aggrs->'total'->>'files')::BIGINT AS files_count,
                                    (local_aggrs->'total'->>'size')::BIGINT AS logical_size,
                                    COALESCE((local_aggrs->'total'->>'blocks_div_nlinks')::BIGINT, (local_aggrs->'total'->>'blocks')::BIGINT) AS blocks
                                FROM sf.dir_current WHERE volume_id={vol_id}
                                        {subtree_str}
                                        {' '.join(excluded_paths_str)}
                        """
                        with mt.with_stat_name("insert_dirs_with_subtree"):
                            result = plpy.execute(insert_stmt)
                            insert_subtree_dirs += result.nrows() or 0

        plpy.execute("ANALYZE tags__dirs_with_inherited_tags")
        log(f"vol: {vol_id} insert subtree: {insert_subtree_count} ({insert_subtree_dirs} dirs) insert recursive: {insert_recursive_count} ({insert_recursive_dirs} dirs)")


    def get_interval_to_age_buckets_case(time_type, time_buckets):
        """
        Build the WHEN/ELSE arms of a SQL CASE expression that maps the age of a single file
        (``current_timestamp - file.<time_type>``) to the label of an age bucket.

        Only the lower bound (``min_days``) of every bucket is compared and a CASE expression
        returns its first matching arm, so ``time_buckets`` has to be ordered from the highest
        ``min_days`` to the lowest. When no arm matches the expression yields ``'future'``.

        Bucket labels are embedded as SQL string literals; single quotes inside a label are doubled
        so that a label such as ``Owner's data`` does not break the generated statement.

        :param time_type: ``'atime'`` or ``'mtime'`` - the column of ``file`` the age is computed from
        :param time_buckets: iterable of ``TimeBucket``
        :return: text to be placed between ``CASE`` and ``END`` (the query must alias the table as ``file``)
        :raises ValueError: when ``time_type`` is neither ``'atime'`` nor ``'mtime'``

        >>> time_buckets = [
        >>>     TimeBucket(min_days=365, max_days=None, description='Previous Years: > 1'),
        >>>     TimeBucket(min_days=180, max_days=364, description='Previous Months: 6-12'),
        >>>     TimeBucket(min_days=0, max_days=179, description='Previous Months: 0-6'),
        >>> ]
        >>> get_interval_to_age_buckets_case('atime', time_buckets)
        WHEN current_timestamp - file.atime >= interval '365 days' THEN 'Previous Years: > 1'
        WHEN current_timestamp - file.atime >= interval '180 days' THEN 'Previous Months: 6-12'
        WHEN current_timestamp - file.atime >= interval '0 days' THEN 'Previous Months: 0-6'
        ELSE 'future'
        >>> get_interval_to_age_buckets_case('mtime', time_buckets)
        WHEN current_timestamp - file.mtime >= interval '365 days' THEN 'Previous Years: > 1'
        WHEN current_timestamp - file.mtime >= interval '180 days' THEN 'Previous Months: 6-12'
        WHEN current_timestamp - file.mtime >= interval '0 days' THEN 'Previous Months: 0-6'
        ELSE 'future'
        >>> get_interval_to_age_buckets_case('foo', time_buckets)
        ValueError: time_type should be either atime or mtime
        """
        if time_type not in ('atime', 'mtime'):
            raise ValueError('time_type should be either atime or mtime')

        age_interval = f'current_timestamp - file.{time_type}'

        arms = []
        for bucket in time_buckets:
            # double single quotes so the label stays a valid SQL string literal
            label = str(bucket.description).replace("'", "''")
            arms.append(f"WHEN {age_interval} >= interval '{bucket.min_days} days' THEN '{label}'")
        arms.append("ELSE 'future'")

        return '\n'.join(arms)


    def time_buckets_to_sql_case(time_type, time_buckets):
        """
        Build the WHEN/ELSE arms of a SQL CASE expression that classifies a whole directory into an
        age bucket using its ``min_<time_type>_days`` / ``max_<time_type>_days`` columns.

        In tags__dirs_with_inherited_tags these columns are computed from the directory's local
        aggregates as ``(now - min timestamp)`` and ``(now - max timestamp)`` in days, i.e.
        ``min_*_days`` is the age of the oldest timestamp (the largest age) and ``max_*_days`` is the
        age of the newest timestamp (the smallest age). A directory is assigned to a bucket only when
        both ends fall inside it; a directory that lies entirely in the future is labelled
        ``'future'`` and everything else is labelled ``'different_groups'``.

        Bucket labels are embedded as SQL string literals; single quotes inside a label are doubled
        so that a label such as ``Owner's data`` does not break the generated statement.

        :param time_type: ``'atime'`` or ``'mtime'``
        :param time_buckets: iterable of ``TimeBucket``; ``max_days=None`` means "no upper bound"
        :return: text to be placed between ``CASE`` and ``END``
        :raises ValueError: when ``time_type`` is neither ``'atime'`` nor ``'mtime'``

        >>> time_buckets = [
        >>>     TimeBucket(min_days=365, max_days=None, description='Previous Years: > 1'),
        >>>     TimeBucket(min_days=180, max_days=364, description='Previous Months: 6-12'),
        >>>     TimeBucket(min_days=0, max_days=179, description='Previous Months: 0-6'),
        >>> ]
        >>> time_buckets_to_sql_case('atime', time_buckets)
        WHEN max_atime_days >= 365 THEN 'Previous Years: > 1'
        WHEN max_atime_days >= 180 AND min_atime_days <= 364 THEN 'Previous Months: 6-12'
        WHEN max_atime_days >= 0 AND min_atime_days <= 179 THEN 'Previous Months: 0-6'
        WHEN min_atime_days < 0 THEN 'future'
        ELSE 'different_groups'
        >>> time_buckets_to_sql_case('mtime', time_buckets)
        WHEN max_mtime_days >= 365 THEN 'Previous Years: > 1'
        WHEN max_mtime_days >= 180 AND min_mtime_days <= 364 THEN 'Previous Months: 6-12'
        WHEN max_mtime_days >= 0 AND min_mtime_days <= 179 THEN 'Previous Months: 0-6'
        WHEN min_mtime_days < 0 THEN 'future'
        ELSE 'different_groups'
        >>> time_buckets_to_sql_case('foo', time_buckets)
        ValueError: time_type should be either atime or mtime
        """
        if time_type not in ('atime', 'mtime'):
            raise ValueError('time_type should be either atime or mtime')

        min_column = f"min_{time_type}_days"
        max_column = f"max_{time_type}_days"

        arms = []
        for bucket in time_buckets:
            # double single quotes so the label stays a valid SQL string literal
            label = str(bucket.description).replace("'", "''")
            condition = f"{max_column} >= {bucket.min_days}"
            if bucket.max_days is not None:
                # bounded bucket: the oldest timestamp must not exceed the upper bound either
                condition += f" AND {min_column} <= {bucket.max_days}"
            arms.append(f"WHEN {condition} THEN '{label}'")

        arms.append(f"WHEN {min_column} < 0 THEN 'future'")
        arms.append("ELSE 'different_groups'")
        return '\n'.join(arms)


    def group_dirs_into_atime_mtime_buckets():
        """
        Copy every row of tags__dirs_with_inherited_tags into tags__dirs_in_atime_mtime_buckets,
        replacing the min/max age columns with the atime and mtime bucket labels produced by
        time_buckets_to_sql_case() for the configured time_buckets.
        """
        with mt.with_stat_name("create_dirs_in_atime_mtime_buckets"):
            # CASE arms are generated once per time type and spliced into the statement below
            atime_case = time_buckets_to_sql_case('atime', time_buckets)
            mtime_case = time_buckets_to_sql_case('mtime', time_buckets)
            insert_sql = f"""
                INSERT INTO tags__dirs_in_atime_mtime_buckets
                    SELECT
                        id,
                        volume_id,
                        CASE
                            {atime_case}
                        END as atime_bucket,
                        CASE
                            {mtime_case}
                        END as mtime_bucket,
                        zone_tag_name_ids,
                        tag_name_ids,
                        files_count,
                        logical_size,
                        blocks
                    FROM tags__dirs_with_inherited_tags
            """
            plpy.execute(insert_sql)


    def move_dirs_with_different_atime_mtime_groups_to_separate_table():
        """
        Move directories whose atime or mtime bucket is 'different_groups' out of
        tags__dirs_in_atime_mtime_buckets and into tags__dirs_with_different_atime_and_mtime_age,
        then refresh planner statistics for both tables and log how many files and directories
        were moved.
        """
        # DELETE ... RETURNING feeds the INSERT, so the move happens in a single statement
        move_sql = """
                WITH tmp AS (
                    DELETE FROM tags__dirs_in_atime_mtime_buckets
                    WHERE atime_bucket = 'different_groups' OR mtime_bucket = 'different_groups'
                    RETURNING *
                )
                INSERT INTO tags__dirs_with_different_atime_and_mtime_age SELECT * FROM tmp
            """
        analyze_statements = (
            "ANALYZE tags__dirs_in_atime_mtime_buckets;",
            "ANALYZE tags__dirs_with_different_atime_and_mtime_age;",
        )
        summary_sql = "SELECT COALESCE(SUM(files_count), 0) as files_count, COUNT(*) as dirs_count FROM tags__dirs_with_different_atime_and_mtime_age"

        with mt.with_stat_name("delete_dirs_with_different_age_groups"):
            plpy.execute(move_sql)

        with mt.with_stat_name("analyze"):
            for analyze_statement in analyze_statements:
                plpy.execute(analyze_statement)

            summary = plpy.execute(summary_sql)[0]
            moved_files = summary['files_count']
            moved_dirs = summary['dirs_count']
            log(f"Files not from local_aggrs: {moved_files:,} (in {moved_dirs:,})")


    def calculate_result_from_local_aggrs():
        """
        Yield report rows aggregated from tags__dirs_in_atime_mtime_buckets.

        Every directory row is expanded to one row per (zone_tag_name_id, tag_name_id) pair and the
        file count, logical size and physical size (``blocks * 512``) are summed per
        (zone tag, tag, atime bucket, mtime bucket).

        ``vol_id`` is not a parameter: it is read from the enclosing scope at the time the rows are
        produced.

        :return: generator of tuples
            (vol_id, zone_tag_name_id, tag_name_id, atime_bucket, mtime_bucket,
             logical_size, physical_size, files_count)
        """
        with mt.with_stat_name("result_from_local_aggrs"):
            cursor = plpy.cursor(f"""
                SELECT
                    atime_bucket,
                    mtime_bucket,
                    zone_tag_name_id,
                    unnest(tag_name_ids) AS tag_name_id,
                    SUM(files_count) AS files_count,
                    SUM(logical_size) AS logical_size,
                    SUM(blocks * 512) AS physical_size
                FROM tags__dirs_in_atime_mtime_buckets
                    CROSS JOIN unnest(tags__dirs_in_atime_mtime_buckets.zone_tag_name_ids) AS zone_tag_name_id
                GROUP BY zone_tag_name_id, tag_name_id, atime_bucket, mtime_bucket
            """)

            # try/finally: release the cursor also when the consumer stops iterating early
            # or when fetching raises, not only after the result set was fully read
            try:
                fetch_size = 1000
                rows = cursor.fetch(fetch_size)
                while rows:
                    for row in rows:
                        yield vol_id, row["zone_tag_name_id"], row["tag_name_id"], row["atime_bucket"], row["mtime_bucket"], row["logical_size"], row["physical_size"], row["files_count"]
                    rows = cursor.fetch(fetch_size)
            finally:
                cursor.close()

    def calculate_result_from_dirs_with_different_atime_mtime_groups():
        """
        Yield report rows for files located directly in the directories stored in
        tags__dirs_with_different_atime_and_mtime_age.

        Each file of the current volume (``vol_id`` is read from the enclosing scope) is bucketed
        individually by its own atime and mtime, and the sums are then expanded per
        (zone_tag_name_id, tag_name_id) pair of the parent directory.

        :return: generator of tuples
            (volume_id, zone_tag_name_id, tag_name_id, file_atime_age, file_mtime_age,
             logical_size, physical_size, files_count)
        """
        # order in which the selected columns are emitted to the caller
        result_columns = (
            "volume_id",
            "zone_tag_name_id",
            "tag_name_id",
            "file_atime_age",
            "file_mtime_age",
            "logical_size",
            "physical_size",
            "files_count",
        )

        with mt.with_stat_name("files_with_different_age_groups"):
            atime_case = get_interval_to_age_buckets_case('atime', time_buckets)
            mtime_case = get_interval_to_age_buckets_case('mtime', time_buckets)
            files_sql = f"""
                SELECT
                    volume_id,
                    SUM(logical_size) as logical_size,
                    SUM(physical_size) as physical_size,
                    SUM(files_count) as files_count,
                    file_atime_age,
                    file_mtime_age,
                    zone_tag_name_id,
                    unnest(tag_name_ids) as tag_name_id
                FROM (
                    -- we make nested query because we dont want to unnest tag_name_ids array
                    -- in this query as it can double or triple query time (I believe that rows are multiplied first
                    -- and then costly CASE statements are calculated for each multiplied row)
                    SELECT
                        file.volume_id,
                        SUM(file.size) AS logical_size,
                        SUM(file.blocks * 512 / CASE WHEN file.nlinks > 0 THEN file.nlinks ELSE 1 END) AS physical_size,
                        COUNT(*) AS files_count,
                        CASE
                            {atime_case}
                        END as file_atime_age,
                        CASE
                            {mtime_case}
                        END as file_mtime_age,
                        (dir.zone_tag_name_ids)::BIGINT[] AS zone_tag_name_ids,
                        (dir.tag_name_ids)::BIGINT[] AS tag_name_ids
                    FROM sf.file_current AS file
                        JOIN tags__dirs_with_different_atime_and_mtime_age AS dir ON file.parent_id = dir.id
                    WHERE file.volume_id = {vol_id}
                    GROUP BY file.volume_id, file_atime_age, file_mtime_age, zone_tag_name_ids, tag_name_ids
                ) t
                CROSS JOIN unnest(t.zone_tag_name_ids) AS zone_tag_name_id
                GROUP BY volume_id, file_atime_age, file_mtime_age, zone_tag_name_id, tag_name_id
            """

            for record in plpy.execute(files_sql):
                yield tuple(record[column] for column in result_columns)

    def calculate_result_from_explicit_tags_on_files(trie, vol_id, include_zone_tag_name_ids):
        """
        Yield report rows for files that carry explicit tags of their own
        (rows of tags__entries_with_explicit_tags that join to sf.file_current).

        Files are grouped per parent directory, tag sets and age buckets. For every group the labels
        inherited by the parent directory are looked up in ``trie`` and removed from the file's own
        tags, so only combinations that are *new* with respect to the directory are emitted:

        * (inherited zone tag, new tag) for every inherited zone tag, and
        * (new zone tag, inherited or new tag) for every zone tag set directly on the files.

        :param trie: paths trie of the volume; must provide ``get_inherited_labels(path)``
        :param vol_id: id of the volume being processed
        :param include_zone_tag_name_ids: when false, rows are emitted with ``zone_tag_name_id=None``
            if the parent directory inherits no zone tag, and zone tags set directly on files are
            treated as an error
        :return: generator of tuples
            (vol_id, zone_tag_name_id, tag_name_id, atime_age, mtime_age,
             logical_size, physical_size, files_count)
        :raises RuntimeError: when files carry zone tags although ``include_zone_tag_name_ids`` is false
            and the parent directory inherits no zone tag
        """
        with mt.with_stat_name("explicit_tags"):
            cursor = plpy.cursor(f"""
                SELECT
                    file.parent_id,
                    dir.path AS parent_path,
                    file.volume_id,
                    SUM(file.size) AS logical_size,
                    SUM(file.blocks * 512 / CASE WHEN file.nlinks > 0 THEN file.nlinks ELSE 1 END) AS physical_size,
                    COUNT(*) AS files_count,
                    CASE
                        {get_interval_to_age_buckets_case('atime', time_buckets)}
                    END as atime_age,
                    CASE
                        {get_interval_to_age_buckets_case('mtime', time_buckets)}
                    END as mtime_age,
                    COALESCE(tag.zone_tag_name_ids, ARRAY[]::BIGINT[]) AS zone_tag_name_ids,
                    COALESCE(tag.tag_name_ids, ARRAY[]::BIGINT[]) AS tag_name_ids
                FROM sf.file_current AS file
                    JOIN tags__entries_with_explicit_tags AS tag ON file.id = tag.id
                    JOIN sf.dir_current AS dir ON file.parent_id = dir.id
                WHERE file.volume_id = {vol_id} AND dir.volume_id = {vol_id}
                GROUP BY file.parent_id, file.volume_id, zone_tag_name_ids, tag_name_ids, atime_age, mtime_age, dir.path
            """)

            # try/finally: release the cursor also when the consumer stops iterating early
            # or when processing a row raises, not only after the result set was fully read
            try:
                fetch_size = 1000
                rows = cursor.fetch(fetch_size)
                while rows:
                    for row in rows:
                        inherited_labels = trie.get_inherited_labels(row["parent_path"])
                        # int(label) because row contains raw ints which are not equal to ZoneTagNameId or TagNameId
                        inherited_zone_tag_name_ids = set(int(label) for label in inherited_labels if isinstance(label, ZoneTagNameId))
                        inherited_tag_name_ids = set(int(label) for label in inherited_labels if isinstance(label, TagNameId))
                        # tags set on the files themselves that the parent directory does not already provide
                        new_zone_tag_name_ids = set(row["zone_tag_name_ids"]) - inherited_zone_tag_name_ids
                        new_tag_name_ids = set(row["tag_name_ids"]) - inherited_tag_name_ids
                        if not inherited_zone_tag_name_ids and not include_zone_tag_name_ids:
                            if new_zone_tag_name_ids:
                                raise RuntimeError('ZoneTagNameIds exists only in zone_namespace_tags')
                            # no zone dimension in this report: emit the new tags once, with a NULL zone
                            inherited_zone_tag_name_ids = {None,}
                        for zone_tag_name_id in inherited_zone_tag_name_ids:
                            for tag_name_id in new_tag_name_ids:
                                # notice that row with exactly the same vol_id, zone_tag_name_id, tag_name_id, atime_age, mtime_age can be
                                # returned more than once (once from dirs and once from explicit tags on files) so overall
                                # output of this function should be grouped by those fields
                                yield vol_id, zone_tag_name_id, tag_name_id, row["atime_age"], row["mtime_age"], row["logical_size"], row["physical_size"], row["files_count"]
                        for zone_tag_name_id in new_zone_tag_name_ids:
                            for tag_name_id in inherited_tag_name_ids | new_tag_name_ids:
                                yield vol_id, zone_tag_name_id, tag_name_id, row["atime_age"], row["mtime_age"], row["logical_size"], row["physical_size"], row["files_count"]

                    rows = cursor.fetch(fetch_size)
            finally:
                cursor.close()

    log(f"====================== START {flavour}_plpython =========================")
    mt = GroupingMeasureTime()

    # Only two report flavours exist; anything else is rejected before any work is done.
    if flavour not in ("tags", "zone_namespace_tags"):
        raise RuntimeError(f"Unexpected flavour: {flavour}")

    # Zone tag name ids are reported only by the zone_namespace_tags flavour.
    include_zone_tag_name_ids = flavour == "zone_namespace_tags"

    # NOTE(review): the two queries below select tag_name_ids / zone_tag_name_ids in opposite
    # order - confirm that prepare_table_entries_with_explicit_tags consumes them by name.
    if not include_zone_tag_name_ids:
        # "tags" flavour: every tag outside of the __archive namespace counts as a plain tag,
        # there are no zone tag ids and no label requirements for directories.
        required_labels_func = None
        explicitly_tagged_entries_sql = """
            SELECT tvc.volume_id,
                   fs_entry_id as id,
                   array_agg(name_id) AS tag_name_ids,
                   ARRAY[]::BIGINT[] AS zone_tag_name_ids,
                    dc.name,
                    dc.path
            FROM sf.tag_value_current as tvc
            LEFT JOIN sf.dir_current AS dc ON (tvc.fs_entry_id = dc.id)
            WHERE name_id IN (
                SELECT tn.id
                FROM sf.tag_name AS tn
                         JOIN sf.tag_namespace AS tns ON tn.namespace_id = tns.id
                WHERE tns.name not in ('__archive')
            )  -- global tags, custom namespace tags + internal __zone tags. Discard internal __archive tags
            GROUP BY tvc.volume_id, fs_entry_id, dc.name, dc.path
        """
    else:
        # "zone_namespace_tags" flavour: tags from the __zone namespace become zone tag ids, tags
        # from non-empty namespaces that do not start with "__" become plain tag ids.
        def in_zone_with_tagspace(labels):
            """Return True when labels hold at least one ZoneTagNameId and at least one TagNameId."""
            if not any(isinstance(label, ZoneTagNameId) for label in labels):
                return False
            return any(isinstance(label, TagNameId) for label in labels)

        required_labels_func = in_zone_with_tagspace
        explicitly_tagged_entries_sql = """
            SELECT tvc.volume_id,
                   fs_entry_id as id,
                   array_agg(name_id) FILTER (WHERE tns.name = '__zone') AS zone_tag_name_ids,
                   array_agg(name_id) FILTER (WHERE tns.name != '' AND SUBSTRING(tns.name, 1, length('__')) != '__') AS tag_name_ids,
                   dc.name,
                   dc.path
            FROM sf.tag_value_current AS tvc
            LEFT JOIN sf.dir_current as dc on (tvc.fs_entry_id = dc.id)
                     JOIN sf.tag_name AS tn ON tvc.name_id = tn.id
                     JOIN sf.tag_namespace AS tns ON tn.namespace_id = tns.id
            GROUP BY tvc.volume_id, fs_entry_id, dc.name, dc.path
        """

    prepare_table_entries_with_explicit_tags(explicitly_tagged_entries_sql)
    prepare_table_dirs_with_explicit_tags()
    trie_per_vol_id = create_tries('tags__entries_with_explicit_tags')

    create_table_dirs_with_inherited_tags()
    create_table_dirs_in_atime_mtime_buckets()
    volume_id_to_files_count, files_total = get_volumes_sizes(trie_per_vol_id)

    files_done = 0
    files_count_for_vol = 0
    last_print_stats_time = time.time()
    for vol_id, trie in trie_per_vol_id.items():
        log_stats_if_necessary(last_print_stats_time)
        files_done += files_count_for_vol
        done_percent = float(files_done / files_total) * 100.0
        files_count_for_vol = volume_id_to_files_count.get(vol_id, 0)
        log(f"started vol: {vol_id} with {files_count_for_vol:,} files. Done: {files_done:,}/{files_total:,} ({done_percent:.2f}%)")
        truncate_dirs_temp_tables()

        calculate_inherited_tags_for_volume(
            vol_id,
            trie,
            required_labels_func=required_labels_func,
        )  # -> populates tags__dirs_with_inherited_tags
        group_dirs_into_atime_mtime_buckets()  # -> populates tags__dirs_in_atime_mtime_buckets by copying from tags__dirs_with_inherited_tags
        move_dirs_with_different_atime_mtime_groups_to_separate_table()  # -> creates tags__dirs_in_atime_mtime_buckets by moving 'different-groups' bucket from tags__dirs_in_atime_mtime_buckets

        for row in calculate_result_from_local_aggrs():  # returns data from tags__dirs_in_atime_mtime_buckets
            yield row

        for row in calculate_result_from_dirs_with_different_atime_mtime_groups():  # calculate data from tags__dirs_with_different_atime_and_mtime_age
            yield row

        for row in calculate_result_from_explicit_tags_on_files(trie, vol_id, include_zone_tag_name_ids):  # returns files from tags__entries_with_explicit_tags
            yield row

    mt.print_all_stats(log)
    log(f"====================== END {flavour}_plpython =========================")

$$ LANGUAGE plpython3u SECURITY DEFINER VOLATILE PARALLEL UNSAFE;
