From 8bfb66a3f9a6e1293b7cc4d72cc02e455be9cea9 Mon Sep 17 00:00:00 2001 From: rabbitali Date: Thu, 8 Jun 2023 10:39:45 +0800 Subject: [PATCH] fix issue: hotpatch status filter exception Move the hp_status condition out of the shared filter sets and apply it in _query_cve_hosts/_query_host_cve, only when the "fixed" filter is set and hp_status is non-empty: with hotpatch == [True] the query is narrowed by hp_status; with hotpatch == [False] hp_status is ignored; otherwise (hotpatch absent, empty or [True, False]) the fixed_by_hp == True rows are narrowed by hp_status and unioned with the fixed_by_hp == False rows. A missing "hotpatch" key is treated as an empty list so len() no longer raises TypeError. --- apollo/database/proxy/cve.py | 22 +++++++++++++++------- apollo/database/proxy/host.py | 22 +++++++++++++++++----- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/apollo/database/proxy/cve.py b/apollo/database/proxy/cve.py index 13a1ae6..24245de 100644 --- a/apollo/database/proxy/cve.py +++ b/apollo/database/proxy/cve.py @@ -187,8 +187,7 @@ class CveMysqlProxy(MysqlProxy): cve_id = data["cve_id"] filters = self._get_cve_hosts_filters(data.get("filter", {})) - cve_hosts_query = self._query_cve_hosts( - data["username"], cve_id, filters) + cve_hosts_query = self._query_cve_hosts(data["username"], cve_id, filters, data.get("filter", {})) total_count = cve_hosts_query.count() if not total_count: @@ -238,33 +237,42 @@ class CveMysqlProxy(MysqlProxy): filters.add(Host.host_group_name.in_(filter_dict["host_group"])) if filter_dict.get("repo"): filters.add(Host.repo_name.in_(filter_dict["repo"])) - if filter_dict.get("hp_status"): - filters.add(CveHostAssociation.hp_status.in_(filter_dict["hp_status"])) + if filter_dict.get("hotpatch") and fixed is True: filters.add(CveHostAssociation.fixed_by_hp.in_(filter_dict["hotpatch"])) elif filter_dict.get("hotpatch") and fixed is False: filters.add(CveHostAssociation.support_hp.in_(filter_dict["hotpatch"])) return filters - def _query_cve_hosts(self, username, cve_id, filters): + def _query_cve_hosts(self, username: str, cve_id: str, filters: set, filter_dict: dict): """ query needed cve hosts info Args: username (str): user name of the request cve_id (str): cve id filters (set): filter given by user - + filter_dict (dict): filter given by user, e.g. { + "fixed": bool, + "hotpatch": [true, false], + "hp_status": [accepted, active] + } Returns: sqlalchemy.orm.query.Query """ cve_query = self.session.query(Host.host_id, Host.host_name, Host.host_ip, 
Host.host_group_name, Host.repo_name, Host.last_scan, CveHostAssociation.support_hp, CveHostAssociation.fixed, CveHostAssociation.fixed_by_hp, - CveHostAssociation.hp_status ) \ + CveHostAssociation.hp_status) \ .join(CveHostAssociation, Host.host_id == CveHostAssociation.host_id) \ .filter(Host.user == username, CveHostAssociation.cve_id == cve_id) \ .filter(*filters) + if filter_dict.get("fixed"): + if filter_dict.get("hotpatch") == [True] and filter_dict.get("hp_status"): + return cve_query.filter(CveHostAssociation.hp_status.in_(filter_dict["hp_status"])) + elif len(filter_dict.get("hotpatch") or []) != 1 and filter_dict.get("hp_status"): + return cve_query.filter(CveHostAssociation.hp_status.in_(filter_dict["hp_status"]), + CveHostAssociation.fixed_by_hp == True).union(cve_query.filter(CveHostAssociation.fixed_by_hp == False)) return cve_query @staticmethod diff --git a/apollo/database/proxy/host.py b/apollo/database/proxy/host.py index 3fdf97b..bc30288 100644 --- a/apollo/database/proxy/host.py +++ b/apollo/database/proxy/host.py @@ -475,7 +475,7 @@ class HostProxy(HostMysqlProxy, CveEsProxy): host_id = data["host_id"] filters = self._get_host_cve_filters(data.get("filter", {})) host_cve_query = self._query_host_cve( - data["username"], host_id, filters) + data["username"], host_id, filters, data.get("filter", {})) total_count = host_cve_query.count() if not total_count: @@ -514,6 +514,8 @@ class HostProxy(HostMysqlProxy, CveEsProxy): Returns: set """ + # "fixed" defaults to False when it is not provided, + # so by default only unfixed CVEs are queried fixed = filter_dict.get("fixed", False) filters = {CveHostAssociation.fixed == fixed} @@ -525,8 +527,6 @@ class HostProxy(HostMysqlProxy, CveEsProxy): "%" + filter_dict["cve_id"] + "%")) if filter_dict.get("severity"): filters.add(Cve.severity.in_(filter_dict["severity"])) - if filter_dict.get("hp_status"): - filters.add(CveHostAssociation.hp_status.in_(filter_dict["hp_status"])) if 
filter_dict.get("hotpatch") and fixed is True: filters.add(CveHostAssociation.fixed_by_hp.in_(filter_dict["hotpatch"])) elif filter_dict.get("hotpatch") and fixed is False: @@ -536,17 +536,22 @@ class HostProxy(HostMysqlProxy, CveEsProxy): filters.add(CveHostAssociation.affected == filter_dict["affected"]) return filters - def _query_host_cve(self, username, host_id, filters): + def _query_host_cve(self, username: str, host_id: int, filters: set, filter_dict: dict): """ query needed host CVEs info Args: username (str): user name of the request host_id (int): host id filters (set): filter given by user - + filter_dict (dict): filter given by user, e.g. { + "fixed": bool, + "hotpatch": [true, false], + "hp_status": [accepted, active] + } Returns: sqlalchemy.orm.query.Query """ + host_cve_query = self.session.query(CveHostAssociation.cve_id, Cve.publish_time, Cve.severity, Cve.cvss_score, CveHostAssociation.fixed, CveHostAssociation.support_hp, CveHostAssociation.fixed_by_hp, CveHostAssociation.hp_status) \ @@ -556,6 +561,13 @@ class HostProxy(HostMysqlProxy, CveEsProxy): .filter(CveHostAssociation.host_id == host_id, Host.user == username) \ .filter(*filters) + if filter_dict.get("fixed"): + if filter_dict.get("hotpatch") == [True] and filter_dict.get("hp_status"): + return host_cve_query.filter(CveHostAssociation.hp_status.in_(filter_dict["hp_status"])) + + elif len(filter_dict.get("hotpatch") or []) != 1 and filter_dict.get("hp_status"): + return host_cve_query.filter(CveHostAssociation.hp_status.in_(filter_dict["hp_status"]), + CveHostAssociation.fixed_by_hp == True).union(host_cve_query.filter(CveHostAssociation.fixed_by_hp == False)) return host_cve_query @staticmethod -- Gitee