aops-apollo/0001-fix-some-apis-which-has-filter-fault.patch
2023-06-01 20:53:46 +08:00

148 lines
6.7 KiB
Diff

From ca1388c59c97d31dbbdae3c48e7033dbc2d11b47 Mon Sep 17 00:00:00 2001
From: rabbitali <shusheng.wen@outlook.com>
Date: Mon, 29 May 2023 17:05:17 +0800
Subject: [PATCH] fix issue where some interface fields do not support filtering
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
apollo/database/proxy/cve.py | 12 +++++-------
apollo/database/proxy/host.py | 7 ++++---
apollo/database/proxy/task.py | 14 ++++----------
apollo/function/schema/task.py | 3 ++-
4 files changed, 15 insertions(+), 21 deletions(-)
diff --git a/apollo/database/proxy/cve.py b/apollo/database/proxy/cve.py
index ed4c1d2..9dc96ae 100644
--- a/apollo/database/proxy/cve.py
+++ b/apollo/database/proxy/cve.py
@@ -867,17 +867,15 @@ class CveProxy(CveMysqlProxy, CveEsProxy):
exist_cve_query = self.session.query(CveHostAssociation.cve_id) \
.join(Host, Host.host_id == CveHostAssociation.host_id) \
.filter(Host.user == username, CveHostAssociation.affected == 1, CveHostAssociation.fixed == 0)
- # get first column value from tuple to list
- exist_cve_list = list(zip(*exist_cve_query))[0]
related_cve_query = self.session.query(CveAffectedPkgs.cve_id) \
- .filter(CveAffectedPkgs.package.in_(pkg_list)) \
- .filter(CveAffectedPkgs.cve_id.in_(exist_cve_list))
- related_cve = set(list(zip(*related_cve_query))[0])
+ .filter(CveAffectedPkgs.package.in_(pkg_list),CveAffectedPkgs.cve_id.in_(exist_cve_query.subquery())) \
+ .distinct()
- related_cve.remove(cve_id)
- return list(related_cve)
+ related_cve = [row[0] for row in related_cve_query.all() if row[0] != cve_id]
+ return related_cve
+
@staticmethod
def _cve_info_row2dict(row, description_dict, pkg_list):
"""
diff --git a/apollo/database/proxy/host.py b/apollo/database/proxy/host.py
index f3fe51e..a9431a9 100644
--- a/apollo/database/proxy/host.py
+++ b/apollo/database/proxy/host.py
@@ -514,7 +514,8 @@ class HostProxy(HostMysqlProxy, CveEsProxy):
Returns:
set
"""
- filters = {CveHostAssociation.fixed == filter_dict.get("fixed", False)}
+ fixed = filter_dict.get("fixed", False)
+ filters = {CveHostAssociation.fixed == fixed}
if not filter_dict:
return filters
@@ -525,9 +526,9 @@ class HostProxy(HostMysqlProxy, CveEsProxy):
if filter_dict.get("severity"):
filters.add(Cve.severity.in_(filter_dict["severity"]))
- if filter_dict.get("hotpatch") and filter_dict.get("fixed") is True:
+ if filter_dict.get("hotpatch") and fixed is True:
filters.add(CveHostAssociation.fixed_by_hp.in_(filter_dict["hotpatch"]))
- elif filter_dict.get("hotpatch") and filter_dict.get("fixed") is False:
+ elif filter_dict.get("hotpatch") and fixed is False:
filters.add(CveHostAssociation.support_hp.in_(filter_dict["hotpatch"]))
if "affected" in filter_dict:
diff --git a/apollo/database/proxy/task.py b/apollo/database/proxy/task.py
index f457043..e660f02 100644
--- a/apollo/database/proxy/task.py
+++ b/apollo/database/proxy/task.py
@@ -924,9 +924,7 @@ class TaskMysqlProxy(MysqlProxy):
filters = set()
if filter_dict.get("cve_id"):
- filters.add(Cve.cve_id.like("%" + filter_dict["cve_id"] + "%"))
- if filter_dict.get("reboot"):
- filters.add(Cve.reboot == filter_dict["reboot"])
+ filters.add(CveHostAssociation.cve_id.like("%" + filter_dict["cve_id"] + "%"))
return filters
def _query_cve_task(self, username, task_id, filters):
@@ -948,12 +946,11 @@ class TaskMysqlProxy(MysqlProxy):
}
"""
task_cve_query = self.session.query(TaskCveHostAssociation.cve_id,
- Cve.reboot,
CveAffectedPkgs.package,
TaskCveHostAssociation.host_id,
TaskCveHostAssociation.status) \
- .outerjoin(Cve, Cve.cve_id == TaskCveHostAssociation.cve_id) \
- .outerjoin(CveAffectedPkgs, CveAffectedPkgs.cve_id == Cve.cve_id) \
+ .outerjoin(CveHostAssociation, CveHostAssociation.cve_id == TaskCveHostAssociation.cve_id) \
+ .outerjoin(CveAffectedPkgs, CveAffectedPkgs.cve_id == CveHostAssociation.cve_id) \
.outerjoin(Task, Task.task_id == TaskCveHostAssociation.task_id) \
.filter(Task.task_id == task_id, Task.username == username) \
.filter(*filters)
@@ -969,7 +966,6 @@ class TaskMysqlProxy(MysqlProxy):
{
"cve_id": "CVE-2021-0001",
"package": "tensorflow",
- "reboot": True,
"host_id": "id1",
"status": "fixed"
}
@@ -979,7 +975,6 @@ class TaskMysqlProxy(MysqlProxy):
[{
"cve_id": "CVE-2021-0001",
"package": "tensorflow",
- "reboot": True,
"host_num": 3,
"status": "running"
}]
@@ -991,8 +986,7 @@ class TaskMysqlProxy(MysqlProxy):
for row in task_cve_query:
cve_id = row.cve_id
if cve_id not in cve_dict:
- cve_dict[cve_id] = {"package": {row.package}, "reboot": row.reboot,
- "host_set": {row.host_id}, "status_set": {row.status}}
+ cve_dict[cve_id] = {"package": {row.package}, "host_set": {row.host_id}, "status_set": {row.status}}
else:
cve_dict[cve_id]["package"].add(row.package)
cve_dict[cve_id]["host_set"].add(row.host_id)
diff --git a/apollo/function/schema/task.py b/apollo/function/schema/task.py
index 1fa776c..472fd53 100644
--- a/apollo/function/schema/task.py
+++ b/apollo/function/schema/task.py
@@ -19,6 +19,7 @@ from marshmallow import Schema
from marshmallow import fields
from marshmallow import validate
+from apollo.conf.constant import TaskType
class TaskListFilterSchema(Schema):
"""
@@ -26,7 +27,7 @@ class TaskListFilterSchema(Schema):
"""
task_name = fields.String(required=False, validate=lambda s: len(s) > 0)
task_type = fields.List(fields.String(
- validate=validate.OneOf(["cve fix", "repo set"])), required=False)
+ validate=validate.OneOf([getattr(TaskType,p) for p in dir(TaskType) if p.isupper()])), required=False)
class GetTaskListSchema(Schema):
--