Files
billai/analyzer/server.py
2026-01-23 14:17:59 +08:00

398 lines
11 KiB
Python

#!/usr/bin/env python3
"""
账单分析 FastAPI 服务
提供 HTTP API 供 Go 服务调用,替代子进程通信方式
"""
import os
import sys
import io
import tempfile
import shutil
from pathlib import Path
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
from typing import Optional
# Force UTF-8 on stdout so that the non-ASCII log lines printed by this module
# cannot raise UnicodeEncodeError on consoles with a different default encoding.
# The comparison is normalised ("UTF-8", "utf8", "utf-8" are all the same codec)
# so an already-UTF-8 stream is not wrapped a second time, and a missing stdout
# (encoding is None / stdout is None) no longer crashes at import time.
_stdout_encoding = (getattr(sys.stdout, "encoding", None) or "").lower().replace("-", "")
if _stdout_encoding != "utf8" and hasattr(sys.stdout, "buffer"):
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
from cleaners.base import compute_date_range_from_values
from cleaners import AlipayCleaner, WechatCleaner
from category import infer_category, get_all_categories, get_all_income_categories
from converter import convert_bill_file
# Application version; this is the value the /health endpoint reports.
APP_VERSION = "0.0.1"
# =============================================================================
# Pydantic 模型
# =============================================================================
class CleanRequest(BaseModel):
    """Request body for cleaning a bill file that is already on disk."""

    # Where to read the raw bill from and where to write the cleaned result.
    input_path: str
    output_path: str

    # Optional date filters; all four are handed to the date-range helper as-is.
    year: str | None = None
    month: str | None = None
    start: str | None = None
    end: str | None = None

    # Output format passed to the cleaner; defaults to "csv".
    format: str | None = "csv"

    # One of "auto", "alipay" or "wechat".
    bill_type: str | None = "auto"
class CleanResponse(BaseModel):
    """Response body returned by the bill-cleaning endpoints."""

    success: bool
    # Bill type that was actually used for cleaning.
    bill_type: str
    # Human-readable outcome message.
    message: str
    # Path of the cleaned file, when one was produced.
    output_path: str | None = None
class CategoryRequest(BaseModel):
    """Request body for inferring the category of a single transaction."""

    # Merchant name of the transaction.
    merchant: str
    # Product / item description of the transaction.
    product: str
    # Direction of the transaction: "收入" (income) or "支出" (expense).
    income_expense: str
class CategoryResponse(BaseModel):
    """Response body carrying the inferred category of a transaction."""

    # Category chosen by the inference helper.
    category: str
    # Certainty flag reported by the inference helper alongside the category.
    is_certain: bool
class HealthResponse(BaseModel):
    """Response body of the health-check endpoint."""

    # Service status string.
    status: str
    # Version string of the running service.
    version: str
class ConvertResponse(BaseModel):
    """Response body of the file-conversion endpoint."""

    success: bool
    # Bill type reported by the converter.
    bill_type: str
    # Path of the converted file.
    output_path: str
    # Human-readable outcome message.
    message: str
# =============================================================================
# 辅助函数
# =============================================================================
def detect_bill_type(filepath: str) -> str | None:
"""
检测账单类型
Returns:
'alipay' | 'wechat' | None
"""
try:
with open(filepath, "r", encoding="utf-8") as f:
for _ in range(50): # 支付宝账单可能有较多的头部信息行
line = f.readline()
if not line:
break
# 支付宝特征
if "交易分类" in line and "对方账号" in line:
return "alipay"
# 微信特征
if "交易类型" in line and "金额(元)" in line:
return "wechat"
# 数据行特征
if line.startswith("202"):
if "" in line:
return "wechat"
if "@" in line:
return "alipay"
except Exception as e:
print(f"读取文件失败: {e}", file=sys.stderr)
return None
return None
def do_clean(
input_path: str,
output_path: str,
bill_type: str = "auto",
year: str = None,
month: str = None,
start: str = None,
end: str = None,
output_format: str = "csv"
) -> tuple[bool, str, str]:
"""
执行清洗逻辑
Returns:
(success, bill_type, message)
"""
# 检查文件是否存在
if not Path(input_path).exists():
return False, "", f"文件不存在: {input_path}"
# 检测账单类型
if bill_type == "auto":
detected_type = detect_bill_type(input_path)
if detected_type is None:
return False, "", "无法识别账单类型"
bill_type = detected_type
# 计算日期范围
start_date, end_date = compute_date_range_from_values(year, month, start, end)
# 创建对应的清理器
try:
if bill_type == "alipay":
cleaner = AlipayCleaner(input_path, output_path, output_format)
else:
cleaner = WechatCleaner(input_path, output_path, output_format)
cleaner.set_date_range(start_date, end_date)
cleaner.clean()
type_names = {"alipay": "支付宝", "wechat": "微信"}
return True, bill_type, f"{type_names[bill_type]}账单清洗完成"
except Exception as e:
return False, bill_type, f"清洗失败: {str(e)}"
# =============================================================================
# FastAPI 应用
# =============================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Print one banner when the application starts and one when it stops."""
    startup_banner = "🚀 账单分析服务启动"
    shutdown_banner = "👋 账单分析服务关闭"

    print(startup_banner)
    yield
    print(shutdown_banner)
# The advertised API version is taken from APP_VERSION so that the generated
# API docs and the /health endpoint cannot drift apart (the literal "1.0.0"
# used here before disagreed with APP_VERSION).
app = FastAPI(
    title="BillAI Analyzer",
    description="账单分析与清洗服务",
    version=APP_VERSION,
    lifespan=lifespan,
)
# =============================================================================
# API 路由
# =============================================================================
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report that the service is up, together with its version."""
    payload = {"status": "ok", "version": APP_VERSION}
    return HealthResponse(**payload)
@app.post("/clean", response_model=CleanResponse)
async def clean_bill(request: CleanRequest):
    """Clean a bill file that already exists on disk.

    The request names an input path and an output path; the cleaned result is
    written to the output path. A failed clean is answered with HTTP 400
    carrying the failure message.
    """
    options = {
        "input_path": request.input_path,
        "output_path": request.output_path,
        "bill_type": request.bill_type or "auto",
        "year": request.year,
        "month": request.month,
        "start": request.start,
        "end": request.end,
        "output_format": request.format or "csv",
    }
    ok, resolved_type, detail = do_clean(**options)

    if ok:
        return CleanResponse(
            success=True,
            bill_type=resolved_type,
            message=detail,
            output_path=request.output_path,
        )

    raise HTTPException(status_code=400, detail=detail)
@app.post("/clean/upload", response_model=CleanResponse)
async def clean_bill_upload(
    file: UploadFile = File(...),
    year: Optional[str] = Form(None),
    month: Optional[str] = Form(None),
    start: Optional[str] = Form(None),
    end: Optional[str] = Form(None),
    format: Optional[str] = Form("csv"),
    bill_type: Optional[str] = Form("auto"),
):
    """Clean a bill file uploaded as multipart/form-data.

    The upload is copied to a temporary file, cleaned into a second temporary
    file, and the path of that second file is returned. On success the output
    file is left in place; nothing in this module deletes it afterwards. On
    failure both temporary files are removed.

    Raises:
        HTTPException: 400 when the clean fails.
    """
    input_path = None
    output_path = None
    keep_output = False

    try:
        # The client may omit the filename; fall back to ".csv" instead of
        # crashing on Path(None).
        suffix = Path(file.filename or "").suffix or ".csv"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_input:
            # Record the name before copying so a failed copy is still cleaned up.
            input_path = tmp_input.name
            shutil.copyfileobj(file.file, tmp_input)

        # Reserve an output path; the cleaner overwrites this empty file.
        output_suffix = ".json" if format == "json" else ".csv"
        with tempfile.NamedTemporaryFile(delete=False, suffix=output_suffix) as tmp_output:
            output_path = tmp_output.name

        success, detected_type, message = do_clean(
            input_path=input_path,
            output_path=output_path,
            bill_type=bill_type or "auto",
            year=year,
            month=month,
            start=start,
            end=end,
            output_format=format or "csv",
        )
        if not success:
            raise HTTPException(status_code=400, detail=message)

        keep_output = True
        return CleanResponse(
            success=True,
            bill_type=detected_type,
            message=message,
            output_path=output_path,
        )
    finally:
        # The uploaded copy is never needed after cleaning.
        if input_path is not None and os.path.exists(input_path):
            os.unlink(input_path)
        # The output file is only handed to the caller on success; otherwise
        # it would be orphaned in the temp directory.
        if not keep_output and output_path is not None and os.path.exists(output_path):
            os.unlink(output_path)
@app.get("/clean/download/{file_path:path}")
async def download_cleaned_file(file_path: str):
    """Send a cleaned file back to the client as a binary download.

    SECURITY: ``file_path`` comes straight from the URL, so without further
    restriction this endpoint can read any file the process can read. Set the
    ``ANALYZER_DOWNLOAD_ROOT`` environment variable to confine downloads to
    one directory tree; when it is unset the previous unrestricted behaviour
    is kept for backward compatibility.

    Raises:
        HTTPException: 404 when the path is not a regular file, or lies
            outside the configured download root.
    """
    not_found = HTTPException(status_code=404, detail="文件不存在")

    allowed_root = os.environ.get("ANALYZER_DOWNLOAD_ROOT")
    if allowed_root:
        try:
            # resolve() collapses ".." and symlinks before the containment test.
            target = Path(file_path).resolve()
            inside_root = target.is_relative_to(Path(allowed_root).resolve())
        except (OSError, ValueError, RuntimeError):
            inside_root = False
        if not inside_root:
            raise not_found

    # isfile() rather than exists(): a directory cannot be served as a file.
    if not os.path.isfile(file_path):
        raise not_found

    return FileResponse(
        file_path,
        filename=Path(file_path).name,
        media_type="application/octet-stream",
    )
@app.post("/category/infer", response_model=CategoryResponse)
async def infer_category_api(request: CategoryRequest):
    """Infer a transaction category from its merchant, product and direction."""
    verdict = infer_category(
        merchant=request.merchant,
        product=request.product,
        income_expense=request.income_expense,
    )
    category, is_certain = verdict
    return CategoryResponse(category=category, is_certain=is_certain)
@app.get("/category/list")
async def list_categories():
    """Return every known category, split into expense and income lists."""
    expense_categories = get_all_categories()
    income_categories = get_all_income_categories()
    return {"expense": expense_categories, "income": income_categories}
@app.post("/detect")
async def detect_bill_type_api(file: UploadFile = File(...)):
    """Detect whether an uploaded bill is an Alipay or a WeChat export.

    The upload is copied to a temporary file, inspected, and the temporary
    file is always removed afterwards.

    Returns:
        A dict with ``bill_type`` (``"alipay"`` / ``"wechat"``) and its
        ``display_name``.

    Raises:
        HTTPException: 400 when the bill type cannot be determined.
    """
    tmp_path = None
    try:
        # The client may omit the filename; fall back to ".csv" instead of
        # crashing on Path(None).
        suffix = Path(file.filename or "").suffix or ".csv"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Record the name before copying so a failed copy is still cleaned up.
            tmp_path = tmp.name
            shutil.copyfileobj(file.file, tmp)

        bill_type = detect_bill_type(tmp_path)
        if bill_type is None:
            raise HTTPException(status_code=400, detail="无法识别账单类型")

        type_names = {"alipay": "支付宝", "wechat": "微信"}
        return {
            "bill_type": bill_type,
            "display_name": type_names[bill_type],
        }
    finally:
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
@app.post("/convert", response_model=ConvertResponse)
async def convert_bill_file_api(file: UploadFile = File(...)):
    """Convert an uploaded bill file and report where the result was written.

    Per the original notes, the converter handles xlsx -> csv and
    GBK/GB2312 -> UTF-8 conversion. The response carries the converted
    file's path and the bill type the converter reported.

    The uploaded copy is always removed; the converted output file is left in
    place and must be cleaned up by the caller.

    Raises:
        HTTPException: 400 when the conversion fails.
    """
    input_path = None
    try:
        # The client may omit the filename; fall back to ".csv" instead of
        # crashing on Path(None).
        suffix = Path(file.filename or "").suffix or ".csv"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Record the name before copying so a failed copy is still cleaned up.
            input_path = tmp.name
            shutil.copyfileobj(file.file, tmp)

        success, bill_type, output_path, message = convert_bill_file(input_path)
        if not success:
            raise HTTPException(status_code=400, detail=message)

        return ConvertResponse(
            success=True,
            bill_type=bill_type,
            output_path=output_path,
            message=message,
        )
    finally:
        # Only the uploaded copy is removed here; the converted output is
        # the caller's responsibility.
        if input_path is not None and os.path.exists(input_path):
            os.unlink(input_path)
# =============================================================================
# 启动入口
# =============================================================================
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn. Host and port come from
    # ANALYZER_HOST / ANALYZER_PORT, defaulting to 0.0.0.0:8001.
    import uvicorn

    env = os.environ
    port = int(env.get("ANALYZER_PORT", 8001))
    host = env.get("ANALYZER_HOST", "0.0.0.0")

    print(f"🚀 启动账单分析服务: http://{host}:{port}")
    uvicorn.run(app, host=host, port=port)