Let me walk through a data-analysis project to show the plan-and-execute framework in action. The example covers data processing, analysis, and visualization tasks.
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import json
from enum import Enum
import logging
from datetime import datetime
# Task status enum
class TaskStatus(Enum):
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"

# Task priority enum
class TaskPriority(Enum):
    LOW = 1
    MEDIUM = 2
    HIGH = 3

# Task definition
@dataclass
class Task:
    id: str
    name: str
    description: str
    priority: TaskPriority
    dependencies: List[str]  # IDs of the tasks this task depends on
    status: TaskStatus
    result: Any = None
    error: Optional[str] = None
# Workflow executor
class WorkflowExecutor:
    def __init__(self):
        self.tasks: Dict[str, Task] = {}
        self.task_handlers: Dict[str, Any] = {}  # task id -> handler callable
        self.logger = logging.getLogger(__name__)

    def add_task(self, task: Task):
        self.tasks[task.id] = task

    def get_ready_tasks(self) -> List[Task]:
        """Return all pending tasks whose dependencies are satisfied."""
        ready_tasks = []
        for task in self.tasks.values():
            if task.status == TaskStatus.PENDING:
                dependencies_met = all(
                    self.tasks[dep_id].status == TaskStatus.COMPLETED
                    for dep_id in task.dependencies
                )
                if dependencies_met:
                    ready_tasks.append(task)
        # Highest priority first
        return sorted(ready_tasks, key=lambda x: x.priority.value, reverse=True)

    def execute_task(self, task: Task):
        """Execute a single task."""
        task.status = TaskStatus.RUNNING
        try:
            # Look up the handler for this task and pass it the results
            # of all of its completed dependencies.
            result = self.task_handlers[task.id](
                task,
                {dep: self.tasks[dep].result for dep in task.dependencies}
            )
            task.result = result
            task.status = TaskStatus.COMPLETED
        except Exception as e:
            task.status = TaskStatus.FAILED
            task.error = str(e)
            self.logger.error(f"Task {task.id} failed: {e}")

    def execute_workflow(self):
        """Execute the whole workflow until no task is ready."""
        while True:
            ready_tasks = self.get_ready_tasks()
            if not ready_tasks:
                break
            for task in ready_tasks:
                self.execute_task(task)
        # Report whether every task finished successfully
        all_completed = all(
            task.status == TaskStatus.COMPLETED
            for task in self.tasks.values()
        )
        return all_completed
# Example: a data analysis workflow
class DataAnalysisWorkflow:
    def __init__(self, data_path: str, output_path: str):
        self.data_path = data_path
        self.output_path = output_path
        self.executor = WorkflowExecutor()

    def plan_workflow(self):
        """Plan the workflow as a set of tasks with dependencies."""
        tasks = [
            Task(
                id="load_data",
                name="Load data",
                description="Load data from a CSV file",
                priority=TaskPriority.HIGH,
                dependencies=[],
                status=TaskStatus.PENDING
            ),
            Task(
                id="clean_data",
                name="Clean data",
                description="Handle missing values and outliers",
                priority=TaskPriority.HIGH,
                dependencies=["load_data"],
                status=TaskStatus.PENDING
            ),
            Task(
                id="feature_engineering",
                name="Feature engineering",
                description="Create new features",
                priority=TaskPriority.MEDIUM,
                dependencies=["clean_data"],
                status=TaskStatus.PENDING
            ),
            Task(
                id="statistical_analysis",
                name="Statistical analysis",
                description="Compute basic statistics",
                priority=TaskPriority.MEDIUM,
                dependencies=["clean_data"],
                status=TaskStatus.PENDING
            ),
            Task(
                id="visualization",
                name="Visualization",
                description="Generate charts",
                priority=TaskPriority.MEDIUM,
                dependencies=["statistical_analysis"],
                status=TaskStatus.PENDING
            ),
            Task(
                id="generate_report",
                name="Generate report",
                description="Produce the analysis report",
                priority=TaskPriority.LOW,
                dependencies=["visualization", "feature_engineering"],
                status=TaskStatus.PENDING
            )
        ]
        for task in tasks:
            self.executor.add_task(task)

    def register_task_handlers(self):
        """Register a handler callable for each task id."""
        self.executor.task_handlers = {
            "load_data": self.load_data,
            "clean_data": self.clean_data,
            "feature_engineering": self.feature_engineering,
            "statistical_analysis": self.statistical_analysis,
            "visualization": self.visualization,
            "generate_report": self.generate_report
        }
    def load_data(self, task: Task, dependencies: Dict):
        import pandas as pd
        df = pd.read_csv(self.data_path)
        return df

    def clean_data(self, task: Task, dependencies: Dict):
        df = dependencies["load_data"]
        # Fill missing values with column means (numeric columns only)
        df = df.fillna(df.mean(numeric_only=True))
        # Handle outliers
        # ... further cleaning logic
        return df

    def feature_engineering(self, task: Task, dependencies: Dict):
        df = dependencies["clean_data"]
        # Create new features
        # ... feature engineering logic
        return df

    def statistical_analysis(self, task: Task, dependencies: Dict):
        df = dependencies["clean_data"]
        stats = {
            "basic_stats": df.describe(),
            "correlations": df.corr(numeric_only=True),
            # ... other statistics
        }
        return stats

    def visualization(self, task: Task, dependencies: Dict):
        import matplotlib.pyplot as plt
        stats = dependencies["statistical_analysis"]
        figures = []
        # Generate charts
        # ... visualization logic
        return figures

    def generate_report(self, task: Task, dependencies: Dict):
        import os
        figures = dependencies["visualization"]
        df_features = dependencies["feature_engineering"]
        os.makedirs(self.output_path, exist_ok=True)
        # Save each matplotlib figure to disk and record its path;
        # Figure objects themselves are not JSON-serializable.
        figure_paths = []
        for i, fig in enumerate(figures):
            path = f"{self.output_path}/figure_{i}.png"
            fig.savefig(path)
            figure_paths.append(path)
        report = {
            "timestamp": datetime.now().isoformat(),
            "statistics": str(dependencies["statistical_analysis"]),
            "features": df_features.columns.tolist(),
            "figures": figure_paths
        }
        # Save the report
        with open(f"{self.output_path}/report.json", "w") as f:
            json.dump(report, f, indent=2)
        return report
    def run(self):
        """Run the complete workflow."""
        self.plan_workflow()
        self.register_task_handlers()
        success = self.executor.execute_workflow()
        if success:
            final_report = self.executor.tasks["generate_report"].result
            print("Workflow completed successfully!")
            return final_report
        else:
            failed_tasks = [
                task for task in self.executor.tasks.values()
                if task.status == TaskStatus.FAILED
            ]
            print("Workflow failed. Failed tasks:")
            for task in failed_tasks:
                print(f"- {task.name}: {task.error}")
            return None

# Usage example
def main():
    workflow = DataAnalysisWorkflow(
        data_path="data/sales_data.csv",
        output_path="output"
    )
    result = workflow.run()
    if result:
        print("Analysis report generated:", result)
    else:
        print("Workflow execution failed")

if __name__ == "__main__":
    main()
This example demonstrates:
- Core components of the workflow framework:
  - Task definition
  - Workflow executor
  - Dependency management
  - Status tracking
  - Error handling
- Key features of the implementation:
  - Explicit task planning
  - Dependency resolution
  - Batch scheduling of independent tasks (the executor runs them sequentially, but get_ready_tasks already identifies which tasks could run in parallel)
  - Result passing between tasks (see the minimal sketch after this list)
  - Per-task failure records for diagnosing errors
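To see dependency resolution and result passing in isolation, here is a minimal sketch with the executor alone; the two tasks and their lambda handlers are hypothetical stand-ins for real work:

executor = WorkflowExecutor()
# "transform" only becomes ready once "extract" has completed.
executor.add_task(Task(id="extract", name="Extract", description="Produce raw data",
                       priority=TaskPriority.HIGH, dependencies=[],
                       status=TaskStatus.PENDING))
executor.add_task(Task(id="transform", name="Transform", description="Consume raw data",
                       priority=TaskPriority.LOW, dependencies=["extract"],
                       status=TaskStatus.PENDING))
executor.task_handlers = {
    "extract": lambda task, deps: [1, 2, 3],
    "transform": lambda task, deps: [x * 2 for x in deps["extract"]],
}
assert executor.execute_workflow()
print(executor.tasks["transform"].result)  # [2, 4, 6]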
- Directions for extension:

# 1. Add a task retry mechanism
class RetryableExecutor(WorkflowExecutor):
    def execute_task(self, task: Task, max_retries: int = 3):
        # The base execute_task catches exceptions itself and marks the
        # task FAILED, so retry on status rather than on a raised exception.
        retries = 0
        while retries < max_retries:
            super().execute_task(task)
            if task.status == TaskStatus.COMPLETED:
                break
            retries += 1
            self.logger.warning(f"Retry {retries}/{max_retries} for task {task.id}")
        # After the loop the task keeps its final status (FAILED if all
        # attempts failed), so execute_workflow will not pick it up again.
# 2. Add progress monitoring
class MonitoredWorkflow(DataAnalysisWorkflow):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.progress_callback = None

    def set_progress_callback(self, callback):
        self.progress_callback = callback

    def update_progress(self, task: Task, status: str):
        if self.progress_callback:
            self.progress_callback(task, status)
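Note that nothing calls update_progress yet; for the callback to fire, the executor would need to invoke it before and after each task (an assumption about the intended wiring, left open above). A usage sketch:

workflow = MonitoredWorkflow(data_path="data/sales_data.csv", output_path="output")
workflow.set_progress_callback(
    lambda task, status: print(f"{datetime.now().isoformat()} {task.name}: {status}")
)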
# 3. Add caching of intermediate results
import os
import pickle

class CachedExecutor(WorkflowExecutor):
    def __init__(self, cache_dir: str):
        super().__init__()
        self.cache_dir = cache_dir

    def get_cached_result(self, task: Task):
        cache_path = f"{self.cache_dir}/{task.id}.cache"
        if os.path.exists(cache_path):
            with open(cache_path, "rb") as f:
                return pickle.load(f)
        return None

    def cache_result(self, task: Task):
        cache_path = f"{self.cache_dir}/{task.id}.cache"
        with open(cache_path, "wb") as f:
            pickle.dump(task.result, f)
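These hooks still have to be consulted during execution; one way to wire them in, continuing CachedExecutor (a sketch, assuming task results are pickle-serializable):

    def execute_task(self, task: Task):
        # Serve a cached result when available, otherwise run and cache.
        cached = self.get_cached_result(task)
        if cached is not None:
            task.result = cached
            task.status = TaskStatus.COMPLETED
            return
        super().execute_task(task)
        if task.status == TaskStatus.COMPLETED:
            self.cache_result(task)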
- Usage suggestions:

# 1. Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# 2. Add performance monitoring
from time import time

class PerformanceMonitor:
    def __init__(self):
        self.task_times = {}

    def start_task(self, task_id: str):
        self.task_times[task_id] = {"start": time()}

    def end_task(self, task_id: str):
        self.task_times[task_id]["end"] = time()

    def get_task_duration(self, task_id: str):
        times = self.task_times[task_id]
        return times["end"] - times["start"]
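To actually collect timings, the monitor has to bracket each task; one integration point is a subclassed executor (the TimedExecutor name here is illustrative, not part of the original framework):

class TimedExecutor(WorkflowExecutor):
    def __init__(self):
        super().__init__()
        self.monitor = PerformanceMonitor()

    def execute_task(self, task: Task):
        # Record wall-clock time around the base implementation.
        self.monitor.start_task(task.id)
        super().execute_task(task)
        self.monitor.end_task(task.id)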
# 3. Implement graceful termination
import signal

class GracefulWorkflow(DataAnalysisWorkflow):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.should_stop = False
        signal.signal(signal.SIGINT, self.handle_interrupt)

    def handle_interrupt(self, signum, frame):
        print("\nReceived interrupt signal. Cleaning up...")
        self.should_stop = True
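The flag only matters if the execution loop checks it; a sketch of a stop-aware run, continuing GracefulWorkflow (the loop inlines execute_workflow so it can bail out between tasks):

    def run(self):
        self.plan_workflow()
        self.register_task_handlers()
        # Stop dispatching new tasks once an interrupt has been received;
        # the task currently running is allowed to finish.
        while not self.should_stop:
            ready = self.executor.get_ready_tasks()
            if not ready:
                break
            for task in ready:
                if self.should_stop:
                    break
                self.executor.execute_task(task)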
This framework fits many scenarios, for example:
- Data processing pipelines
- ETL workflows
- Machine learning experiments
- Report generation systems
- Automated test flows
The key is to adapt the task definitions and execution logic to the needs at hand, as the sketch below illustrates.
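For instance, pointing the same executor at a three-step ETL pipeline only means swapping in a different task list and handlers (the extract/transform/load ids below are illustrative):

class ETLWorkflow:
    def __init__(self):
        self.executor = WorkflowExecutor()

    def plan(self):
        # Linear chain: extract -> transform -> load
        steps = [("extract", []), ("transform", ["extract"]), ("load", ["transform"])]
        for step_id, deps in steps:
            self.executor.add_task(Task(
                id=step_id,
                name=step_id.title(),
                description=f"{step_id} step",
                priority=TaskPriority.MEDIUM,
                dependencies=deps,
                status=TaskStatus.PENDING,
            ))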