""" 账单清理基类和公共工具函数 """ import csv import json import argparse from abc import ABC, abstractmethod from datetime import datetime, date, timedelta from decimal import Decimal, ROUND_HALF_UP from pathlib import Path # ============================================================================= # 公共工具函数 # ============================================================================= def parse_date(date_str: str) -> date: """解析日期字符串,支持 YYYY-MM-DD 或 YYYY/MM/DD 格式""" for fmt in ("%Y-%m-%d", "%Y/%m/%d"): try: return datetime.strptime(date_str, fmt).date() except ValueError: continue raise ValueError(f"无法解析日期: {date_str},请使用 YYYY-MM-DD 格式") def parse_amount(amount_str: str) -> Decimal: """解析金额字符串为Decimal(去掉¥/¥符号)""" try: # 同时处理全角¥和半角¥ clean = amount_str.replace("¥", "").replace("¥", "").replace(" ", "").strip() return Decimal(clean) except: return Decimal("0") def format_amount(amount: Decimal) -> str: """格式化金额为字符串(保留两位小数)""" return str(amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)) def compute_date_range(args) -> tuple[date | None, date | None]: """ 根据参数计算最终的日期范围 多重指定时取交集(最小范围) Returns: (start_date, end_date) 或 (None, None) 表示不筛选 """ start_date = None end_date = None # 1. 根据年份设置范围 if args.year: year = int(args.year) start_date = date(year, 1, 1) end_date = date(year, 12, 31) # 2. 根据月份进一步收窄 if args.month: month = int(args.month) year = int(args.year) if args.year else datetime.now().year if not start_date: start_date = date(year, 1, 1) end_date = date(year, 12, 31) month_start = date(year, month, 1) if month == 12: month_end = date(year, 12, 31) else: month_end = date(year, month + 1, 1) - timedelta(days=1) start_date = max(start_date, month_start) if start_date else month_start end_date = min(end_date, month_end) if end_date else month_end # 3. 根据 start/end 参数进一步收窄 if args.start: custom_start = parse_date(args.start) start_date = max(start_date, custom_start) if start_date else custom_start if args.end: custom_end = parse_date(args.end) end_date = min(end_date, custom_end) if end_date else custom_end return start_date, end_date def compute_date_range_from_values( year: str = None, month: str = None, start: str = None, end: str = None ) -> tuple[date | None, date | None]: """ 根据参数值计算日期范围(不依赖 argparse) 供 HTTP API 调用使用 Returns: (start_date, end_date) 或 (None, None) 表示不筛选 """ start_date = None end_date = None # 1. 根据年份设置范围 if year: y = int(year) start_date = date(y, 1, 1) end_date = date(y, 12, 31) # 2. 根据月份进一步收窄 if month: m = int(month) y = int(year) if year else datetime.now().year if not start_date: start_date = date(y, 1, 1) end_date = date(y, 12, 31) month_start = date(y, m, 1) if m == 12: month_end = date(y, 12, 31) else: month_end = date(y, m + 1, 1) - timedelta(days=1) start_date = max(start_date, month_start) if start_date else month_start end_date = min(end_date, month_end) if end_date else month_end # 3. 根据 start/end 参数进一步收窄 if start: custom_start = parse_date(start) start_date = max(start_date, custom_start) if start_date else custom_start if end: custom_end = parse_date(end) end_date = min(end_date, custom_end) if end_date else custom_end return start_date, end_date def is_in_date_range(date_str: str, start_date: date | None, end_date: date | None) -> bool: """检查日期字符串是否在指定范围内""" if start_date is None and end_date is None: return True try: row_date = datetime.strptime(date_str[:10], "%Y-%m-%d").date() except ValueError: return False if start_date and row_date < start_date: return False if end_date and row_date > end_date: return False return True def create_arg_parser(description: str) -> argparse.ArgumentParser: """创建通用的命令行参数解析器""" parser = argparse.ArgumentParser( description=description, formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 日期筛选说明: --year 指定年份(如 2026) --month 指定月份(1-12) --start 起始日期(YYYY-MM-DD) --end 结束日期(YYYY-MM-DD) 多个条件同时指定时,取交集(最小日期范围) 输出格式: --format 输出格式:csv(默认)或 json """ ) parser.add_argument("input_file", help="输入的账单CSV文件") parser.add_argument("output_file", nargs="?", default=None, help="输出文件(默认为 输入文件名_cleaned.csv/json)") parser.add_argument("--year", "-y", type=str, default=None, help="保留的年份(如 2026)") parser.add_argument("--month", "-m", type=int, choices=range(1, 13), metavar="1-12", help="保留的月份(1-12)") parser.add_argument("--start", "-s", type=str, help="起始日期(YYYY-MM-DD)") parser.add_argument("--end", "-e", type=str, help="结束日期(YYYY-MM-DD)") parser.add_argument("--format", "-f", choices=["csv", "json"], default="csv", help="输出格式:csv(默认)或 json") return parser def get_output_file(input_file: str, output_file: str | None, output_format: str = "csv") -> str: """获取输出文件路径""" if output_file: return output_file import os base_name = os.path.splitext(input_file)[0] ext = "json" if output_format == "json" else "csv" return f"{base_name}_cleaned.{ext}" # ============================================================================= # 账单清理基类 # ============================================================================= class BaseCleaner(ABC): """账单清理基类""" def __init__(self, input_file: str, output_file: str | None = None, output_format: str = "csv"): self.input_file = input_file self.output_format = output_format self.output_file = get_output_file(input_file, output_file, output_format) self.start_date: date | None = None self.end_date: date | None = None # 统计信息 self.stats = { "original_count": 0, "filtered_count": 0, "fully_refunded": 0, "partially_refunded": 0, "category_adjusted": 0, "final_count": 0, } def set_date_range(self, start_date: date | None, end_date: date | None): """设置日期筛选范围""" self.start_date = start_date self.end_date = end_date def print_header(self): """打印处理头信息""" print(f"输入文件: {self.input_file}") print(f"输出文件: {self.output_file}") print(f"输出格式: {self.output_format.upper()}") if self.start_date or self.end_date: print(f"日期范围: {self.start_date or '不限'} ~ {self.end_date or '不限'}") else: print("日期范围: 全部") print() def write_output(self, header: list, rows: list): """ 写入输出文件(支持 CSV 和 JSON 格式) Args: header: 表头列表 rows: 数据行列表 """ if self.output_format == "json": self._write_json(header, rows) else: self._write_csv(header, rows) def _write_csv(self, header: list, rows: list): """写入 CSV 格式""" with open(self.output_file, "w", encoding="utf-8", newline="") as f: writer = csv.writer(f) writer.writerow(header) writer.writerows(rows) def _write_json(self, header: list, rows: list): """写入 JSON 格式""" # 将每行转换为字典 data = [] for row in rows: record = {} for i, col in enumerate(header): if i < len(row): record[col] = row[i] else: record[col] = "" data.append(record) with open(self.output_file, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=2) @abstractmethod def clean(self) -> None: """执行清理,子类实现""" pass @abstractmethod def reclassify(self, rows: list) -> list: """ 重新分类(子类实现) Args: rows: 待处理的数据行 Returns: 处理后的数据行 """ pass