From 8e19d92b6a484ddcf7ca7bf666ce21baa56ab326 Mon Sep 17 00:00:00 2001 From: rearcher <123781007@qq.com> Date: Wed, 20 Dec 2023 17:21:22 +0800 Subject: [PATCH] fix TimedCorrectTask --- apollo/cron/timed_correct_manager.py | 12 ++++-- apollo/database/proxy/task/base.py | 50 ++++++++++++++++++----- apollo/database/proxy/task/timed_proxy.py | 8 ++++ apollo/tests/database/test_task.py | 2 +- 4 files changed, 57 insertions(+), 15 deletions(-) diff --git a/apollo/cron/timed_correct_manager.py b/apollo/cron/timed_correct_manager.py index db0b4c1..ae3a1e3 100644 --- a/apollo/cron/timed_correct_manager.py +++ b/apollo/cron/timed_correct_manager.py @@ -41,11 +41,15 @@ class TimedCorrectTask(TimedTask): """ Start the correct after the specified time of day. """ - LOGGER.info("Begin to correct the whole host in %s.", str(datetime.datetime.now())) + LOGGER.info( + "Begin to correct the status of timeout tasks and scan timeout host in %s.", + str(datetime.datetime.now())) abnormal_task_ids, abnormal_host_ids = self.get_abnormal_task() - self._update_host_status(abnormal_host_ids) - with TimedProxy() as proxy: - proxy.timed_correct_error_task_status(abnormal_task_ids) + if len(abnormal_host_ids) != 0: + self._update_host_status(abnormal_host_ids) + if len(abnormal_task_ids) != 0: + with TimedProxy() as proxy: + proxy.timed_correct_error_task_status(abnormal_task_ids) @staticmethod def _abnormal_task(tasks): diff --git a/apollo/database/proxy/task/base.py b/apollo/database/proxy/task/base.py index 840c140..a5ddede 100644 --- a/apollo/database/proxy/task/base.py +++ b/apollo/database/proxy/task/base.py @@ -861,17 +861,17 @@ class TaskProxy(TaskMysqlProxy, TaskEsProxy): raise EsOperationError("Delete task from elasticsearch failed due to internal error.") - def get_running_task_form_task_cve_host(self) -> list: + def get_running_task_form_hotpatch_remove_task(self) -> list: """ - Get all CVE repair tasks with running status under Username + Get all hotpatch remove tasks with running status under Username Returns: list: task id list """ - task_cve_query = ( - self.session.query(HotpatchRemoveTask).filter(HotpatchRemoveTask.status == TaskStatus.RUNNING).all() + hotpatch_remove_query = ( + self.session.query(HotpatchRemoveTask.task_id).filter(HotpatchRemoveTask.status == TaskStatus.RUNNING).all() ) - task_id_list = [task.task_id for task in task_cve_query] + task_id_list = [task.task_id for task in hotpatch_remove_query] return task_id_list def get_running_task_form_task_host_repo(self) -> list: @@ -882,13 +882,39 @@ class TaskProxy(TaskMysqlProxy, TaskEsProxy): list: task id list """ host_repo_query = ( - self.session.query(TaskHostRepoAssociation) + self.session.query(TaskHostRepoAssociation.task_id) .filter(TaskHostRepoAssociation.status == TaskStatus.RUNNING) .all() ) task_id_list = [task.task_id for task in host_repo_query] return task_id_list + def get_running_task_form_cve_fix_task(self) -> list: + """ + Get all CVE fix tasks with running status + + Returns: + list: task id list + """ + cve_fix_query = ( + self.session.query(CveFixTask.task_id).filter(CveFixTask.status == TaskStatus.RUNNING).all() + ) + task_id_list = [task.task_id for task in cve_fix_query] + return task_id_list + + def get_running_task_form_cve_rollback_task(self) -> list: + """ + Get all CVE rollback tasks with running status + + Returns: + list: task id list + """ + cve_rollback_query = ( + self.session.query(CveRollbackTask.task_id).filter(CveRollbackTask.status == TaskStatus.RUNNING).all() + ) + task_id_list = [task.task_id for task in cve_rollback_query] + return task_id_list + def get_scanning_status_and_time_from_host(self) -> list: """ Get all host id and time with scanning status from the host table @@ -907,13 +933,17 @@ class TaskProxy(TaskMysqlProxy, TaskEsProxy): Returns: list: Each element is a task information, including the task ID, task type, creation time """ - task_cve_id_list = self.get_running_task_form_task_cve_host() - task_repo_id_list = self.get_running_task_form_task_host_repo() host_info_list = self.get_scanning_status_and_time_from_host() - task_id_list = task_cve_id_list + task_repo_id_list + + task_cve_id_list = self.get_running_task_form_hotpatch_remove_task() + task_repo_id_list = self.get_running_task_form_task_host_repo() + task_cve_fix_list = self.get_running_task_form_cve_fix_task() + task_cve_rollback_list = self.get_running_task_form_cve_rollback_task() + + task_id_list = task_cve_id_list + task_repo_id_list + task_cve_fix_list + task_cve_rollback_list task_query = self.session.query(Task).filter(Task.task_id.in_(task_id_list)).all() - running_task_list = [(task.task_id, task.create_time) for task in task_query] + running_task_list = [(task.task_id, task.latest_execute_time) for task in task_query] return running_task_list, host_info_list def validate_cves(self, cve_id: list) -> bool: diff --git a/apollo/database/proxy/task/timed_proxy.py b/apollo/database/proxy/task/timed_proxy.py index 436c3bd..fd396d1 100644 --- a/apollo/database/proxy/task/timed_proxy.py +++ b/apollo/database/proxy/task/timed_proxy.py @@ -22,6 +22,8 @@ from apollo.conf.constant import TaskStatus from apollo.database.table import ( HotpatchRemoveTask, TaskHostRepoAssociation, + CveFixTask, + CveRollbackTask, ) @@ -42,6 +44,12 @@ class TimedProxy(MysqlProxy): self.session.query(TaskHostRepoAssociation).filter(TaskHostRepoAssociation.task_id.in_(task_ids)).update( {TaskHostRepoAssociation.status: TaskStatus.UNKNOWN}, synchronize_session=False ) + self.session.query(CveFixTask).filter(CveFixTask.task_id.in_(task_ids)).update( + {CveFixTask.status: TaskStatus.UNKNOWN}, synchronize_session=False + ) + self.session.query(CveRollbackTask).filter(CveRollbackTask.task_id.in_(task_ids)).update( + {CveRollbackTask.status: TaskStatus.UNKNOWN}, synchronize_session=False + ) self.session.commit() except SQLAlchemyError as error: self.session.rollback() diff --git a/apollo/tests/database/test_task.py b/apollo/tests/database/test_task.py index 35d923b..ceb84ab 100644 --- a/apollo/tests/database/test_task.py +++ b/apollo/tests/database/test_task.py @@ -401,7 +401,7 @@ class TestTaskMysqlFirst(DatabaseTestCase): def test_get_running_task_form_task_cve_host(self): self.assertEqual( - self.task_database.get_running_task_form_task_cve_host(), + self.task_database.get_running_task_form_hotpatch_remove_task(), ["1111111111poiuytrewqasdfghjklmnb"], ) -- 2.33.0