From a032e1e0b11365a0dc5d725fd234771cd53c0858 Mon Sep 17 00:00:00 2001
From: gongzt
Date: Fri, 2 Jun 2023 14:29:57 +0800
Subject: [PATCH] Repair Host cve verification is not performed in a generation task
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 apollo/database/proxy/task.py       | 37 +++++++++++++++++++++++++++++
 apollo/handler/task_handler/view.py | 20 ++++++++++++++++
 2 files changed, 57 insertions(+)

diff --git a/apollo/database/proxy/task.py b/apollo/database/proxy/task.py
index e660f02..edba161 100644
--- a/apollo/database/proxy/task.py
+++ b/apollo/database/proxy/task.py
@@ -3208,3 +3208,40 @@ class TaskProxy(TaskMysqlProxy, TaskEsProxy):
 
         # insert task id and username into es
         self._init_task_in_es(task_id, data["username"])
+
+    def validate_cves(self, cve_id: list) -> bool:
+        """
+        Check that every CVE id in cve_id has at least one CveHostAssociation row.
+
+        Args:
+            cve_id: ids of the CVEs to validate (expected to be de-duplicated)
+
+        Returns:
+            bool: True if every CVE id was found; False otherwise or on a database error
+        """
+        # CveHostAssociation may hold several rows per CVE, so count distinct cve ids only
+        try:
+            exists_cve_count = self.session.query(CveHostAssociation.cve_id).filter(
+                CveHostAssociation.cve_id.in_(cve_id)).distinct().count()
+
+            return exists_cve_count == len(cve_id)
+        except SQLAlchemyError as error:
+            LOGGER.error(error)
+            return False
+
+    def validate_hosts(self, host_id: list) -> bool:
+        """
+        Check that every host id in host_id exists in the Host table.
+
+        Args:
+            host_id: ids of the hosts to validate (expected to be de-duplicated)
+
+        Returns:
+            bool: True if every host id was found; False otherwise or on a database error
+        """
+        try:
+            exists_host_count = self.session.query(Host).filter(Host.host_id.in_(host_id)).count()
+            return exists_host_count == len(host_id)
+        except SQLAlchemyError as error:
+            LOGGER.error(error)
+            return False
diff --git a/apollo/handler/task_handler/view.py b/apollo/handler/task_handler/view.py
index 214053c..314f7bb 100644
--- a/apollo/handler/task_handler/view.py
+++ b/apollo/handler/task_handler/view.py
@@ -287,6 +287,14 @@ class VulGenerateCveTask(BaseResponse):
                 "task_id": "id1"
             }
         """
+        host_ids = [host["host_id"] for cve_info in params["info"] for host in cve_info["host_info"]]
+        if not callback.validate_hosts(host_id=list(set(host_ids))):
+            return self.response(code=PARAM_ERROR)
+
+        cve_ids = [cve_info["cve_id"] for cve_info in params["info"]]
+        if not callback.validate_cves(cve_id=list(set(cve_ids))):
+            return self.response(code=PARAM_ERROR)
+
         status_code, data = self._handle(callback, params)
         return self.response(code=status_code, data=data)
 
@@ -488,6 +496,10 @@ class VulGenerateRepoTask(BaseResponse):
                 "task_id": "1"
             }
         """
+        host_ids = [host["host_id"] for host in params["info"]]
+        if not callback.validate_hosts(host_id=list(set(host_ids))):
+            return self.response(code=PARAM_ERROR)
+
         status_code, data = self._handle(callback, params)
         return self.response(code=status_code, data=data)
 
@@ -836,6 +848,14 @@ class VulGenerateCveRollback(BaseResponse):
                 "task_id": "1"
             }
         """
+        host_ids = [host["host_id"] for host in params["info"]]
+        if not callback.validate_hosts(host_id=list(set(host_ids))):
+            return self.response(code=PARAM_ERROR)
+
+        cve_ids = [cve["cve_id"] for host in params["info"] for cve in host["cves"]]
+        if not callback.validate_cves(cve_id=list(set(cve_ids))):
+            return self.response(code=PARAM_ERROR)
+
         status_code, data = self._handle(callback, params)
         return self.response(code=status_code, data=data)
 
-- 
Gitee