From ca1388c59c97d31dbbdae3c48e7033dbc2d11b47 Mon Sep 17 00:00:00 2001 From: rabbitali Date: Mon, 29 May 2023 17:05:17 +0800 Subject: [PATCH] fix issue where some fields of the interface cannot support filtering MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apollo/database/proxy/cve.py | 12 +++++------- apollo/database/proxy/host.py | 7 ++++--- apollo/database/proxy/task.py | 14 ++++---------- apollo/function/schema/task.py | 3 ++- 4 files changed, 15 insertions(+), 21 deletions(-) diff --git a/apollo/database/proxy/cve.py b/apollo/database/proxy/cve.py index ed4c1d2..9dc96ae 100644 --- a/apollo/database/proxy/cve.py +++ b/apollo/database/proxy/cve.py @@ -867,17 +867,15 @@ class CveProxy(CveMysqlProxy, CveEsProxy): exist_cve_query = self.session.query(CveHostAssociation.cve_id) \ .join(Host, Host.host_id == CveHostAssociation.host_id) \ .filter(Host.user == username, CveHostAssociation.affected == 1, CveHostAssociation.fixed == 0) - # get first column value from tuple to list - exist_cve_list = list(zip(*exist_cve_query))[0] related_cve_query = self.session.query(CveAffectedPkgs.cve_id) \ - .filter(CveAffectedPkgs.package.in_(pkg_list)) \ - .filter(CveAffectedPkgs.cve_id.in_(exist_cve_list)) - related_cve = set(list(zip(*related_cve_query))[0]) + .filter(CveAffectedPkgs.package.in_(pkg_list),CveAffectedPkgs.cve_id.in_(exist_cve_query.subquery())) \ + .distinct() - related_cve.remove(cve_id) - return list(related_cve) + related_cve = [row[0] for row in related_cve_query.all() if row[0] != cve_id] + return related_cve + @staticmethod def _cve_info_row2dict(row, description_dict, pkg_list): """ diff --git a/apollo/database/proxy/host.py b/apollo/database/proxy/host.py index f3fe51e..a9431a9 100644 --- a/apollo/database/proxy/host.py +++ b/apollo/database/proxy/host.py @@ -514,7 +514,8 @@ class HostProxy(HostMysqlProxy, CveEsProxy): Returns: set """ - filters = {CveHostAssociation.fixed == filter_dict.get("fixed", False)} + fixed = filter_dict.get("fixed", False) + filters = {CveHostAssociation.fixed == fixed} if not filter_dict: return filters @@ -525,9 +526,9 @@ class HostProxy(HostMysqlProxy, CveEsProxy): if filter_dict.get("severity"): filters.add(Cve.severity.in_(filter_dict["severity"])) - if filter_dict.get("hotpatch") and filter_dict.get("fixed") is True: + if filter_dict.get("hotpatch") and fixed is True: filters.add(CveHostAssociation.fixed_by_hp.in_(filter_dict["hotpatch"])) - elif filter_dict.get("hotpatch") and filter_dict.get("fixed") is False: + elif filter_dict.get("hotpatch") and fixed is False: filters.add(CveHostAssociation.support_hp.in_(filter_dict["hotpatch"])) if "affected" in filter_dict: diff --git a/apollo/database/proxy/task.py b/apollo/database/proxy/task.py index f457043..e660f02 100644 --- a/apollo/database/proxy/task.py +++ b/apollo/database/proxy/task.py @@ -924,9 +924,7 @@ class TaskMysqlProxy(MysqlProxy): filters = set() if filter_dict.get("cve_id"): - filters.add(Cve.cve_id.like("%" + filter_dict["cve_id"] + "%")) - if filter_dict.get("reboot"): - filters.add(Cve.reboot == filter_dict["reboot"]) + filters.add(CveHostAssociation.cve_id.like("%" + filter_dict["cve_id"] + "%")) return filters def _query_cve_task(self, username, task_id, filters): @@ -948,12 +946,11 @@ class TaskMysqlProxy(MysqlProxy): } """ task_cve_query = self.session.query(TaskCveHostAssociation.cve_id, - Cve.reboot, CveAffectedPkgs.package, TaskCveHostAssociation.host_id, TaskCveHostAssociation.status) \ - .outerjoin(Cve, Cve.cve_id == TaskCveHostAssociation.cve_id) \ - .outerjoin(CveAffectedPkgs, CveAffectedPkgs.cve_id == Cve.cve_id) \ + .outerjoin(CveHostAssociation, CveHostAssociation.cve_id == TaskCveHostAssociation.cve_id) \ + .outerjoin(CveAffectedPkgs, CveAffectedPkgs.cve_id == CveHostAssociation.cve_id) \ .outerjoin(Task, Task.task_id == TaskCveHostAssociation.task_id) \ .filter(Task.task_id == task_id, Task.username == username) \ .filter(*filters) @@ -969,7 +966,6 @@ class TaskMysqlProxy(MysqlProxy): { "cve_id": "CVE-2021-0001", "package": "tensorflow", - "reboot": True, "host_id": "id1", "status": "fixed" } @@ -979,7 +975,6 @@ class TaskMysqlProxy(MysqlProxy): [{ "cve_id": "CVE-2021-0001", "package": "tensorflow", - "reboot": True, "host_num": 3, "status": "running" }] @@ -991,8 +986,7 @@ class TaskMysqlProxy(MysqlProxy): for row in task_cve_query: cve_id = row.cve_id if cve_id not in cve_dict: - cve_dict[cve_id] = {"package": {row.package}, "reboot": row.reboot, - "host_set": {row.host_id}, "status_set": {row.status}} + cve_dict[cve_id] = {"package": {row.package}, "host_set": {row.host_id}, "status_set": {row.status}} else: cve_dict[cve_id]["package"].add(row.package) cve_dict[cve_id]["host_set"].add(row.host_id) diff --git a/apollo/function/schema/task.py b/apollo/function/schema/task.py index 1fa776c..472fd53 100644 --- a/apollo/function/schema/task.py +++ b/apollo/function/schema/task.py @@ -19,6 +19,7 @@ from marshmallow import Schema from marshmallow import fields from marshmallow import validate +from apollo.conf.constant import TaskType class TaskListFilterSchema(Schema): """ @@ -26,7 +27,7 @@ class TaskListFilterSchema(Schema): """ task_name = fields.String(required=False, validate=lambda s: len(s) > 0) task_type = fields.List(fields.String( - validate=validate.OneOf(["cve fix", "repo set"])), required=False) + validate=validate.OneOf([getattr(TaskType,p) for p in dir(TaskType) if p.isupper()])), required=False) class GetTaskListSchema(Schema): --