Files
billai/analyzer/cleaners/base.py
cheliangzhao e2e1beb6f7 feat: implement cross-batch Alipay refund reconciliation
When a refund row in an uploaded Alipay bill has no matching expense
row in the same batch (because the original purchase was uploaded in a
prior batch), the refund is now reconciled against the stored record in
bills_cleaned rather than being silently discarded.

Changes:
- analyzer/cleaners/base.py: add unresolved_refunds list to BaseCleaner
- analyzer/cleaners/alipay.py: _aggregate_refunds stores full refund
  metadata (dict); _process_expenses tracks matched keys and populates
  self.unresolved_refunds for unmatched refunds
- analyzer/server.py: thread unresolved_refunds through do_clean,
  CleanResponse, and both /clean endpoints
- server/adapter/adapter.go: add UnresolvedRefund type and field to CleanResult
- server/adapter/http/cleaner.go: deserialize unresolved_refunds from
  Python response and populate CleanResult
- server/repository/repository.go: add ReconcileRefund to BillRepository interface
- server/repository/mongo/repository.go: implement ReconcileRefund —
  full refund soft-deletes the bill, partial refund reduces amount and
  appends remark with original amount and refund order number
- server/handler/upload.go: capture clean result and call ReconcileRefund
  for each unresolved refund after saving cleaned bills
- server/model/response.go: add ReconciledRefundCount to UploadData

Also: add CLAUDE.md (@AGENTS.md), update AGENTS.md, fix DailyTrendChart
missing-date gap by filling zero-expense dates in daily map.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-16 19:29:47 +08:00

297 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
账单清理基类和公共工具函数
"""
import argparse
import calendar
import csv
import json
from abc import ABC, abstractmethod
from datetime import date, datetime, timedelta
from decimal import ROUND_HALF_UP, Decimal, InvalidOperation
from pathlib import Path
# =============================================================================
# 公共工具函数
# =============================================================================
def parse_date(date_str: str) -> date:
    """Parse a date string in ``YYYY-MM-DD`` or ``YYYY/MM/DD`` form.

    The formats are tried in that order; the first one that matches wins.

    Raises:
        ValueError: if ``date_str`` matches neither format.
    """
    accepted_formats = ("%Y-%m-%d", "%Y/%m/%d")
    for pattern in accepted_formats:
        try:
            parsed = datetime.strptime(date_str, pattern)
        except ValueError:
            continue
        return parsed.date()
    raise ValueError(f"无法解析日期: {date_str},请使用 YYYY-MM-DD 格式")
def parse_amount(amount_str: str) -> Decimal:
    """Parse an amount string into a Decimal.

    Removes the full-width yen sign (U+FFE5), the half-width yen sign
    (U+00A5) and spaces before conversion, then strips surrounding
    whitespace. Input that cannot be parsed -- malformed text, an empty
    string, or a non-string value -- yields ``Decimal("0")``.
    """
    try:
        # Escapes keep the two look-alike currency symbols unambiguous in source.
        clean = (
            amount_str.replace("\uffe5", "")
            .replace("\u00a5", "")
            .replace(" ", "")
            .strip()
        )
        return Decimal(clean)
    except (InvalidOperation, AttributeError, TypeError):
        # Best-effort: unparsable amounts count as zero. Only the errors that
        # bad input can produce are caught, so Ctrl-C / SystemExit propagate.
        return Decimal("0")
def format_amount(amount: Decimal) -> str:
    """Format an amount as a string with exactly two decimal places.

    Rounds half away from zero (ROUND_HALF_UP), e.g. ``1.005 -> "1.01"`` and
    ``-1.005 -> "-1.01"``. A value that rounds to zero is always rendered as
    ``"0.00"``, never ``"-0.00"``.
    """
    quantized = amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
    if quantized.is_zero():
        # Decimal keeps the sign of zero (e.g. -0.001 -> -0.00); drop it so
        # amounts that round to nothing don't print as "-0.00".
        quantized = quantized.copy_abs()
    return str(quantized)
def compute_date_range(args) -> tuple[date | None, date | None]:
    """Compute the effective date range from parsed command-line arguments.

    Thin adapter over ``compute_date_range_from_values`` so the CLI and the
    HTTP code paths share a single implementation. ``args`` must expose
    ``year``, ``month``, ``start`` and ``end`` attributes (the options
    defined by ``create_arg_parser``). When several filters are given the
    result is their intersection (the smallest range).

    Returns:
        ``(start_date, end_date)``; ``(None, None)`` means "no filtering".

    Raises:
        ValueError: if a filter value cannot be turned into a valid date.
    """
    return compute_date_range_from_values(
        year=args.year,
        month=args.month,
        start=args.start,
        end=args.end,
    )
def compute_date_range_from_values(
year: str = None,
month: str = None,
start: str = None,
end: str = None
) -> tuple[date | None, date | None]:
"""
根据参数值计算日期范围(不依赖 argparse
供 HTTP API 调用使用
Returns:
(start_date, end_date) 或 (None, None) 表示不筛选
"""
start_date = None
end_date = None
# 1. 根据年份设置范围
if year:
y = int(year)
start_date = date(y, 1, 1)
end_date = date(y, 12, 31)
# 2. 根据月份进一步收窄
if month:
m = int(month)
y = int(year) if year else datetime.now().year
if not start_date:
start_date = date(y, 1, 1)
end_date = date(y, 12, 31)
month_start = date(y, m, 1)
if m == 12:
month_end = date(y, 12, 31)
else:
month_end = date(y, m + 1, 1) - timedelta(days=1)
start_date = max(start_date, month_start) if start_date else month_start
end_date = min(end_date, month_end) if end_date else month_end
# 3. 根据 start/end 参数进一步收窄
if start:
custom_start = parse_date(start)
start_date = max(start_date, custom_start) if start_date else custom_start
if end:
custom_end = parse_date(end)
end_date = min(end_date, custom_end) if end_date else custom_end
return start_date, end_date
def is_in_date_range(date_str: str, start_date: date | None, end_date: date | None) -> bool:
"""检查日期字符串是否在指定范围内"""
if start_date is None and end_date is None:
return True
try:
row_date = datetime.strptime(date_str[:10], "%Y-%m-%d").date()
except ValueError:
return False
if start_date and row_date < start_date:
return False
if end_date and row_date > end_date:
return False
return True
def create_arg_parser(description: str) -> argparse.ArgumentParser:
    """Build the command-line parser shared by the bill cleaners.

    Positional arguments: the input bill CSV and an optional output path.
    Options: ``--year/-y``, ``--month/-m`` (1-12), ``--start/-s`` and
    ``--end/-e`` for date filtering, plus ``--format/-f`` (csv or json,
    default csv). The epilog is shown verbatim (RawDescriptionHelpFormatter).
    """
    # NOTE(review): several help/epilog strings below have unbalanced
    # parentheses (e.g. "(如 2026") -- likely full-width characters lost in
    # transit; confirm against the original file before editing the text.
    epilog_lines = [
        "日期筛选说明:",
        "--year 指定年份(如 2026",
        "--month 指定月份1-12",
        "--start 起始日期YYYY-MM-DD",
        "--end 结束日期YYYY-MM-DD",
        "多个条件同时指定时,取交集(最小日期范围)",
        "输出格式:",
        "--format 输出格式csv默认或 json",
    ]
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="\n" + "\n".join(epilog_lines) + "\n",
    )

    # Input / output paths.
    parser.add_argument("input_file", help="输入的账单CSV文件")
    parser.add_argument(
        "output_file",
        nargs="?",
        default=None,
        help="输出文件(默认为 输入文件名_cleaned.csv/json",
    )

    # Date-range filters.
    parser.add_argument("--year", "-y", type=str, default=None, help="保留的年份(如 2026")
    parser.add_argument(
        "--month",
        "-m",
        type=int,
        choices=range(1, 13),
        metavar="1-12",
        help="保留的月份1-12",
    )
    for long_flag, short_flag, help_text in (
        ("--start", "-s", "起始日期YYYY-MM-DD"),
        ("--end", "-e", "结束日期YYYY-MM-DD"),
    ):
        parser.add_argument(long_flag, short_flag, type=str, help=help_text)

    # Output format.
    parser.add_argument(
        "--format",
        "-f",
        choices=["csv", "json"],
        default="csv",
        help="输出格式csv默认或 json",
    )
    return parser
def get_output_file(input_file: str, output_file: str | None, output_format: str = "csv") -> str:
"""获取输出文件路径"""
if output_file:
return output_file
import os
base_name = os.path.splitext(input_file)[0]
ext = "json" if output_format == "json" else "csv"
return f"{base_name}_cleaned.{ext}"
# =============================================================================
# 账单清理基类
# =============================================================================
class BaseCleaner(ABC):
"""账单清理基类"""
def __init__(self, input_file: str, output_file: str | None = None, output_format: str = "csv"):
self.input_file = input_file
self.output_format = output_format
self.output_file = get_output_file(input_file, output_file, output_format)
self.start_date: date | None = None
self.end_date: date | None = None
# 统计信息
self.stats = {
"original_count": 0,
"filtered_count": 0,
"fully_refunded": 0,
"partially_refunded": 0,
"category_adjusted": 0,
"final_count": 0,
}
# 本次清理中未能在同批次内匹配到对应支出的退款(跨批次核销用)
self.unresolved_refunds: list[dict] = []
def set_date_range(self, start_date: date | None, end_date: date | None):
"""设置日期筛选范围"""
self.start_date = start_date
self.end_date = end_date
def print_header(self):
"""打印处理头信息"""
print(f"输入文件: {self.input_file}")
print(f"输出文件: {self.output_file}")
print(f"输出格式: {self.output_format.upper()}")
if self.start_date or self.end_date:
print(f"日期范围: {self.start_date or '不限'} ~ {self.end_date or '不限'}")
else:
print("日期范围: 全部")
print()
def write_output(self, header: list, rows: list):
"""
写入输出文件(支持 CSV 和 JSON 格式)
Args:
header: 表头列表
rows: 数据行列表
"""
if self.output_format == "json":
self._write_json(header, rows)
else:
self._write_csv(header, rows)
def _write_csv(self, header: list, rows: list):
"""写入 CSV 格式"""
with open(self.output_file, "w", encoding="utf-8", newline="") as f:
writer = csv.writer(f)
writer.writerow(header)
writer.writerows(rows)
def _write_json(self, header: list, rows: list):
"""写入 JSON 格式"""
# 将每行转换为字典
data = []
for row in rows:
record = {}
for i, col in enumerate(header):
if i < len(row):
record[col] = row[i]
else:
record[col] = ""
data.append(record)
with open(self.output_file, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
@abstractmethod
def clean(self) -> None:
"""执行清理,子类实现"""
pass
@abstractmethod
def reclassify(self, rows: list) -> list:
"""
重新分类(子类实现)
Args:
rows: 待处理的数据行
Returns:
处理后的数据行
"""
pass