Files
billai/server/repository/mongo/repository.go
cheliangzhao e2e1beb6f7 feat: implement cross-batch Alipay refund reconciliation
When a refund row in an uploaded Alipay bill has no matching expense
row in the same batch (because the original purchase was uploaded in a
prior batch), the refund is now reconciled against the stored record in
bills_cleaned rather than being silently discarded.

Changes:
- analyzer/cleaners/base.py: add unresolved_refunds list to BaseCleaner
- analyzer/cleaners/alipay.py: _aggregate_refunds stores full refund
  metadata (dict); _process_expenses tracks matched keys and populates
  self.unresolved_refunds for unmatched refunds
- analyzer/server.py: thread unresolved_refunds through do_clean,
  CleanResponse, and both /clean endpoints
- server/adapter/adapter.go: add UnresolvedRefund type and field to CleanResult
- server/adapter/http/cleaner.go: deserialize unresolved_refunds from
  Python response and populate CleanResult
- server/repository/repository.go: add ReconcileRefund to BillRepository interface
- server/repository/mongo/repository.go: implement ReconcileRefund —
  full refund soft-deletes the bill, partial refund reduces amount and
  appends remark with original amount and refund order number
- server/handler/upload.go: capture clean result and call ReconcileRefund
  for each unresolved refund after saving cleaned bills
- server/model/response.go: add ReconciledRefundCount to UploadData

Also: add CLAUDE.md (@AGENTS.md), update AGENTS.md, fix DailyTrendChart
missing-date gap by filling zero-expense dates in daily map.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-16 19:29:47 +08:00

595 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Package mongo 实现基于 MongoDB 的账单存储
package mongo
import (
"context"
"fmt"
"math"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"billai-server/config"
"billai-server/model"
"billai-server/repository"
)
// refundEpsilon is the tolerance applied to the amount left on a bill after a
// refund has been subtracted: ReconcileRefund treats a remainder at or below
// this value as a full refund.
const refundEpsilon = 0.005
// Repository is the MongoDB-backed implementation of the bill store.
//
// All fields are nil on a freshly constructed value and are populated by
// Connect.
type Repository struct {
	client            *mongo.Client     // driver client; closed by Disconnect
	db                *mongo.Database   // database named by config.Global.MongoDatabase
	rawCollection     *mongo.Collection // raw bills (config.Global.MongoRawCollection)
	cleanedCollection *mongo.Collection // cleaned bills (config.Global.MongoCleanedCollection)
}
// NewRepository returns an unconnected Repository. Call Connect before using
// any of the data-access methods.
func NewRepository() *Repository {
	return new(Repository)
}
// Connect dials the MongoDB deployment named by config.Global.MongoURI,
// verifies it with a ping, and binds the database and the raw/cleaned
// collection handles named in config.Global.
//
// Dialing and the ping share a single 10-second deadline. If the ping fails,
// the freshly created client is disconnected before returning so that a failed
// Connect does not leave a live client behind. The receiver's fields are only
// assigned on success.
func (r *Repository) Connect() error {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	clientOptions := options.Client().ApplyURI(config.Global.MongoURI)

	client, err := mongo.Connect(ctx, clientOptions)
	if err != nil {
		return fmt.Errorf("连接 MongoDB 失败: %w", err)
	}

	if err := client.Ping(ctx, nil); err != nil {
		// ctx may already have expired (that is one way the ping fails), so
		// tear the client down on a fresh, short deadline.
		closeCtx, closeCancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer closeCancel()
		// The ping failure is the error worth reporting; a teardown error on a
		// client that never became usable adds nothing for the caller.
		_ = client.Disconnect(closeCtx)
		return fmt.Errorf("MongoDB Ping 失败: %w", err)
	}

	r.client = client
	r.db = client.Database(config.Global.MongoDatabase)
	r.rawCollection = r.db.Collection(config.Global.MongoRawCollection)
	r.cleanedCollection = r.db.Collection(config.Global.MongoCleanedCollection)

	fmt.Printf("🍃 MongoDB 连接成功: %s\n", config.Global.MongoDatabase)
	fmt.Printf(" 📄 原始数据集合: %s\n", config.Global.MongoRawCollection)
	fmt.Printf(" 📄 清洗数据集合: %s\n", config.Global.MongoCleanedCollection)
	return nil
}
// Disconnect closes the MongoDB client, allowing up to 10 seconds for the
// shutdown. It is a no-op when no client has been set (Connect never
// succeeded).
func (r *Repository) Disconnect() error {
	client := r.client
	if client == nil {
		return nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	err := client.Disconnect(ctx)
	if err != nil {
		return fmt.Errorf("断开 MongoDB 连接失败: %w", err)
	}

	fmt.Println("🍃 MongoDB 连接已断开")
	return nil
}
// SaveRawBills inserts the given raw bills into the raw collection in a single
// InsertMany call (30-second deadline) and returns how many documents were
// inserted. An empty input returns (0, nil) without touching the database; on
// an insert error the returned count is 0.
func (r *Repository) SaveRawBills(bills []model.RawBill) (int, error) {
	total := len(bills)
	if total == 0 {
		return 0, nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// InsertMany takes []interface{}, so the typed slice has to be re-boxed.
	docs := make([]interface{}, 0, total)
	for i := range bills {
		docs = append(docs, bills[i])
	}

	inserted, err := r.rawCollection.InsertMany(ctx, docs)
	if err != nil {
		return 0, fmt.Errorf("插入原始账单失败: %w", err)
	}
	return len(inserted.InsertedIDs), nil
}
// SaveCleanedBills stores cleaned bills one at a time, skipping any bill that
// CheckCleanedDuplicate reports as already present.
//
// It returns the number of bills inserted and the number skipped as
// duplicates. Processing stops at the first error; the counts accumulated up
// to that point are returned alongside it. All inserts share one 30-second
// deadline (each duplicate check uses its own).
func (r *Repository) SaveCleanedBills(bills []model.CleanedBill) (saved int, duplicates int, err error) {
	if len(bills) == 0 {
		return 0, 0, nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	for i := range bills {
		dup, checkErr := r.CheckCleanedDuplicate(&bills[i])
		if checkErr != nil {
			return saved, duplicates, checkErr
		}
		if dup {
			duplicates++
			continue
		}

		if _, insertErr := r.cleanedCollection.InsertOne(ctx, bills[i]); insertErr != nil {
			return saved, duplicates, fmt.Errorf("插入清洗后账单失败: %w", insertErr)
		}
		saved++
	}
	return saved, duplicates, nil
}
// CheckRawDuplicate reports whether the raw collection already holds at least
// one document whose raw_data.<fieldName> equals value.
//
// The lookup is the same count query CountRawByField performs (5-second
// deadline); any error from it is returned unchanged.
func (r *Repository) CheckRawDuplicate(fieldName, value string) (bool, error) {
	matches, err := r.CountRawByField(fieldName, value)
	if err != nil {
		return false, err
	}
	return matches > 0, nil
}
// CheckCleanedDuplicate reports whether a cleaned bill equivalent to bill is
// already stored.
//
// When the bill carries a transaction ID, that alone identifies it. Otherwise
// the combination of time, amount and merchant is used. The query does not
// filter on is_deleted, so soft-deleted records count as existing. The lookup
// runs under a 5-second deadline.
func (r *Repository) CheckCleanedDuplicate(bill *model.CleanedBill) (bool, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Preferred key: the transaction order number.
	filter := bson.M{"transaction_id": bill.TransactionID}
	if bill.TransactionID == "" {
		// Fallback key: time + amount + merchant.
		filter = bson.M{
			"time":     bill.Time.Time(), // converted to time.Time for the MongoDB query
			"amount":   bill.Amount,
			"merchant": bill.Merchant,
		}
	}

	matches, err := r.cleanedCollection.CountDocuments(ctx, filter)
	if err != nil {
		return false, err
	}
	return matches > 0, nil
}
// CountRawByField returns how many raw documents have raw_data.<fieldName>
// equal to value. The query runs under a 5-second deadline and the driver's
// error, if any, is returned as is.
func (r *Repository) CountRawByField(fieldName, value string) (int64, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	key := "raw_data." + fieldName
	return r.rawCollection.CountDocuments(ctx, bson.M{key: value})
}
// GetCleanedBills returns every cleaned bill matching filter, newest first.
//
// Soft-deleted records are always excluded: the is_deleted condition is set
// after the caller's entries are copied, so an "is_deleted" key supplied in
// filter is overwritten. The caller's map is not modified. The query and the
// result decoding share a 30-second deadline.
func (r *Repository) GetCleanedBills(filter map[string]interface{}) ([]model.CleanedBill, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	query := make(bson.M, len(filter)+1)
	for key, val := range filter {
		query[key] = val
	}
	query["is_deleted"] = bson.M{"$ne": true}

	newestFirst := options.Find().SetSort(bson.D{{Key: "time", Value: -1}})

	cursor, err := r.cleanedCollection.Find(ctx, query, newestFirst)
	if err != nil {
		return nil, fmt.Errorf("查询账单失败: %w", err)
	}
	defer cursor.Close(ctx)

	var bills []model.CleanedBill
	if err = cursor.All(ctx, &bills); err != nil {
		return nil, fmt.Errorf("解析账单数据失败: %w", err)
	}
	return bills, nil
}
// GetCleanedBillsPaged returns one page of cleaned bills matching filter,
// newest first, together with the total number of matching records.
//
// Soft-deleted records are always excluded; an "is_deleted" key supplied in
// filter is overwritten. page is 1-based and pageSize is passed through as the
// query limit unchanged.
//
// The offset is computed in int64 and clamped at 0, so a page below 1 (or any
// page/pageSize combination that would produce a negative offset) is served
// from the first record instead of being sent to the server as a negative
// skip value.
func (r *Repository) GetCleanedBillsPaged(filter map[string]interface{}, page, pageSize int) ([]model.CleanedBill, int64, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	bsonFilter := bson.M{}
	for k, v := range filter {
		bsonFilter[k] = v
	}
	bsonFilter["is_deleted"] = bson.M{"$ne": true}

	// The total is counted over the whole filtered set, not just this page.
	total, err := r.cleanedCollection.CountDocuments(ctx, bsonFilter)
	if err != nil {
		return nil, 0, fmt.Errorf("统计账单数量失败: %w", err)
	}

	// Widen before multiplying so the product cannot wrap in a 32-bit int.
	skip := int64(page-1) * int64(pageSize)
	if skip < 0 {
		skip = 0
	}

	opts := options.Find().
		SetSort(bson.D{{Key: "time", Value: -1}}).
		SetSkip(skip).
		SetLimit(int64(pageSize))

	cursor, err := r.cleanedCollection.Find(ctx, bsonFilter, opts)
	if err != nil {
		return nil, 0, fmt.Errorf("查询账单失败: %w", err)
	}
	defer cursor.Close(ctx)

	var bills []model.CleanedBill
	if err := cursor.All(ctx, &bills); err != nil {
		return nil, 0, fmt.Errorf("解析账单数据失败: %w", err)
	}
	return bills, total, nil
}
// GetBillsAggregate sums the amount of all non-deleted cleaned bills matching
// filter, grouped by their income_expense value, and returns the totals of the
// "支出" (expense) and "收入" (income) groups. Groups with any other
// income_expense value are ignored; a group that is absent leaves its total
// at 0.
//
// An "is_deleted" key supplied in filter is overwritten. The aggregation runs
// under a 30-second deadline.
func (r *Repository) GetBillsAggregate(filter map[string]interface{}) (totalExpense float64, totalIncome float64, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	match := make(bson.M, len(filter)+1)
	for key, val := range filter {
		match[key] = val
	}
	match["is_deleted"] = bson.M{"$ne": true}

	matchStage := bson.D{{Key: "$match", Value: match}}
	// One output document per distinct income_expense value.
	groupStage := bson.D{{Key: "$group", Value: bson.D{
		{Key: "_id", Value: "$income_expense"},
		{Key: "total", Value: bson.D{{Key: "$sum", Value: "$amount"}}},
	}}}

	cursor, err := r.cleanedCollection.Aggregate(ctx, mongo.Pipeline{matchStage, groupStage})
	if err != nil {
		return 0, 0, fmt.Errorf("聚合统计失败: %w", err)
	}
	defer cursor.Close(ctx)

	type sideTotal struct {
		ID    string  `bson:"_id"`
		Total float64 `bson:"total"`
	}
	var totals []sideTotal
	if err := cursor.All(ctx, &totals); err != nil {
		return 0, 0, fmt.Errorf("解析聚合结果失败: %w", err)
	}

	for _, row := range totals {
		if row.ID == "支出" {
			totalExpense = row.Total
		} else if row.ID == "收入" {
			totalIncome = row.Total
		}
	}
	return totalExpense, totalIncome, nil
}
// GetBillsNeedReview returns the non-deleted cleaned bills whose review_level
// is "HIGH" or "LOW", in the order produced by GetCleanedBills.
func (r *Repository) GetBillsNeedReview() ([]model.CleanedBill, error) {
	levels := []string{"HIGH", "LOW"}

	filter := make(map[string]interface{}, 2)
	filter["review_level"] = bson.M{"$in": levels}
	filter["is_deleted"] = bson.M{"$ne": true}

	return r.GetCleanedBills(filter)
}
// GetMonthlyStats returns per-month expense and income totals over all
// non-deleted cleaned bills, sorted by month ascending. It takes no filter, so
// the result always covers the whole collection.
//
// Months are the "%Y-%m" rendering of each bill's time field.
// NOTE(review): $dateToString is given no timezone here, which by MongoDB's
// documented default means UTC bucketing — confirm that matches how `time` is
// stored and how users expect month boundaries to fall.
func (r *Repository) GetMonthlyStats() ([]model.MonthlyStat, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// Stage 1: drop soft-deleted records.
	notDeleted := bson.D{{Key: "$match", Value: bson.M{"is_deleted": bson.M{"$ne": true}}}}

	// Stage 2: derive a "YYYY-MM" month key from the bill time.
	withMonth := bson.D{{Key: "$addFields", Value: bson.D{
		{Key: "month", Value: bson.D{
			{Key: "$dateToString", Value: bson.D{
				{Key: "format", Value: "%Y-%m"},
				{Key: "date", Value: "$time"},
			}},
		}},
	}}}

	// Stage 3: total per (month, income_expense) pair.
	perMonthAndKind := bson.D{{Key: "$group", Value: bson.D{
		{Key: "_id", Value: bson.D{
			{Key: "month", Value: "$month"},
			{Key: "income_expense", Value: "$income_expense"},
		}},
		{Key: "total", Value: bson.D{{Key: "$sum", Value: "$amount"}}},
	}}}

	// sumWhere builds a $sum that only counts the totals of one
	// income_expense kind and contributes 0 for every other kind.
	sumWhere := func(kind string) bson.D {
		return bson.D{
			{Key: "$sum", Value: bson.D{
				{Key: "$cond", Value: bson.A{
					bson.D{{Key: "$eq", Value: bson.A{"$_id.income_expense", kind}}},
					"$total",
					0,
				}},
			}},
		}
	}

	// Stage 4: fold the pairs back into one document per month.
	perMonth := bson.D{{Key: "$group", Value: bson.D{
		{Key: "_id", Value: "$_id.month"},
		{Key: "expense", Value: sumWhere("支出")},
		{Key: "income", Value: sumWhere("收入")},
	}}}

	// Stage 5: chronological order (the month key sorts lexicographically).
	byMonth := bson.D{{Key: "$sort", Value: bson.D{{Key: "_id", Value: 1}}}}

	pipeline := mongo.Pipeline{notDeleted, withMonth, perMonthAndKind, perMonth, byMonth}

	cursor, err := r.cleanedCollection.Aggregate(ctx, pipeline)
	if err != nil {
		return nil, fmt.Errorf("月度统计聚合失败: %w", err)
	}
	defer cursor.Close(ctx)

	type monthRow struct {
		Month   string  `bson:"_id"`
		Expense float64 `bson:"expense"`
		Income  float64 `bson:"income"`
	}
	var rows []monthRow
	if err := cursor.All(ctx, &rows); err != nil {
		return nil, fmt.Errorf("解析月度统计结果失败: %w", err)
	}

	stats := make([]model.MonthlyStat, 0, len(rows))
	for _, row := range rows {
		stats = append(stats, model.MonthlyStat{
			Month:   row.Month,
			Expense: row.Expense,
			Income:  row.Income,
		})
	}
	return stats, nil
}
// UpdateCleanedBillByID applies updates as a $set to the cleaned bill whose
// _id is the given hex ObjectID and returns the document as it looks after the
// update.
//
// It fails when the collection is not initialized, when id is not a valid
// ObjectID hex string, or when updates is empty. repository.ErrNotFound is
// returned when no document has that _id. The operation runs under a
// 10-second deadline.
func (r *Repository) UpdateCleanedBillByID(id string, updates map[string]interface{}) (*model.CleanedBill, error) {
	if r.cleanedCollection == nil {
		return nil, fmt.Errorf("cleaned collection not initialized")
	}

	oid, err := primitive.ObjectIDFromHex(id)
	if err != nil {
		return nil, fmt.Errorf("invalid id: %w", err)
	}

	if len(updates) == 0 {
		return nil, fmt.Errorf("no updates")
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// Ask for the post-update document rather than the pre-update one.
	returnAfter := options.FindOneAndUpdate().SetReturnDocument(options.After)
	res := r.cleanedCollection.FindOneAndUpdate(
		ctx,
		bson.M{"_id": oid},
		bson.M{"$set": updates},
		returnAfter,
	)

	updated := new(model.CleanedBill)
	switch decodeErr := res.Decode(updated); decodeErr {
	case nil:
		return updated, nil
	case mongo.ErrNoDocuments:
		return nil, repository.ErrNotFound
	default:
		return nil, fmt.Errorf("update bill failed: %w", decodeErr)
	}
}
// DeleteCleanedBillByID soft-deletes the cleaned bill with the given hex
// ObjectID: it sets is_deleted to true and stamps updated_at with the current
// time. The document itself is kept.
//
// It fails when the collection is not initialized or id is not a valid
// ObjectID hex string, and returns repository.ErrNotFound when no document has
// that _id. The operation runs under a 10-second deadline.
func (r *Repository) DeleteCleanedBillByID(id string) error {
	if r.cleanedCollection == nil {
		return fmt.Errorf("cleaned collection not initialized")
	}

	oid, err := primitive.ObjectIDFromHex(id)
	if err != nil {
		return fmt.Errorf("invalid id: %w", err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	markDeleted := bson.M{"$set": bson.M{
		"is_deleted": true,
		"updated_at": time.Now(), // record when the bill was last touched
	}}

	res, err := r.cleanedCollection.UpdateOne(ctx, bson.M{"_id": oid}, markDeleted)
	switch {
	case err != nil:
		return fmt.Errorf("soft delete bill failed: %w", err)
	case res.MatchedCount == 0:
		return repository.ErrNotFound
	}
	return nil
}
// SoftDeleteJDRelatedBills soft-deletes every not-yet-deleted cleaned bill
// whose bill_type is not "jd" and whose description contains the text
// "京东-订单编号", and returns the number of documents modified.
//
// Its purpose is to keep a JD order from being counted twice when the same
// payment also shows up in a bill from another source (WeChat, Alipay).
// The update runs under a 30-second deadline.
func (r *Repository) SoftDeleteJDRelatedBills() (int64, error) {
	if r.cleanedCollection == nil {
		return 0, fmt.Errorf("cleaned collection not initialized")
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	candidates := bson.M{}
	// Only bills from sources other than JD itself.
	candidates["bill_type"] = bson.M{"$ne": "jd"}
	// Description mentions a JD order number.
	candidates["description"] = bson.M{"$regex": "京东-订单编号", "$options": ""}
	// Skip anything already soft-deleted.
	candidates["is_deleted"] = bson.M{"$ne": true}

	markDeleted := bson.M{"$set": bson.M{
		"is_deleted": true,
		"updated_at": time.Now(),
	}}

	res, err := r.cleanedCollection.UpdateMany(ctx, candidates, markDeleted)
	if err != nil {
		return 0, fmt.Errorf("soft delete JD related bills failed: %w", err)
	}
	return res.ModifiedCount, nil
}
// ReconcileRefund applies a refund whose original purchase was stored by an
// earlier upload to that stored cleaned bill.
//
// The target is one non-deleted bill of the given billType whose
// transaction_id equals orderNo or whose merchant_order_no equals
// merchantOrderNo (empty keys are left out of the match). If the amount left
// after subtracting refundAmount is at or below refundEpsilon, the bill is
// soft-deleted as fully refunded. Otherwise its amount is reduced to the
// remainder, rounded to two decimals. In both cases a note naming the refund
// amount and refundOrderNo is prepended to the bill's remark and updated_at is
// refreshed.
//
// The return value is true only when a bill was actually updated. It is
// (false, nil) when both keys are empty, when refundAmount is not a positive
// number, when no bill matches, or when the matched bill is gone by the time
// the update is issued.
//
// refundTime, merchant and description are accepted but not used here.
//
// NOTE(review): nothing in this method detects that the same refund was
// already applied, so submitting the same refund twice subtracts it twice —
// confirm the caller prevents that.
func (r *Repository) ReconcileRefund(billType, orderNo, merchantOrderNo string, refundAmount float64, refundTime, merchant, description, refundOrderNo string) (bool, error) {
	if r.cleanedCollection == nil {
		return false, fmt.Errorf("cleaned collection not initialized")
	}
	if orderNo == "" && merchantOrderNo == "" {
		return false, nil
	}
	// A zero, negative or NaN refund must never be applied: it would either
	// leave the amount unchanged while rewriting the remark, inflate the
	// amount, or soft-delete a bill that merely has a near-zero amount.
	if math.IsNaN(refundAmount) || refundAmount <= 0 {
		return false, nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	var or []bson.M
	if orderNo != "" {
		or = append(or, bson.M{"transaction_id": orderNo})
	}
	if merchantOrderNo != "" {
		or = append(or, bson.M{"merchant_order_no": merchantOrderNo})
	}
	filter := bson.M{
		"bill_type":  billType,
		"is_deleted": bson.M{"$ne": true},
		"$or":        or,
	}

	var bill model.CleanedBill
	if err := r.cleanedCollection.FindOne(ctx, filter).Decode(&bill); err != nil {
		if err == mongo.ErrNoDocuments {
			return false, nil
		}
		return false, fmt.Errorf("查询待核销账单失败: %w", err)
	}

	remaining := bill.Amount - refundAmount
	set := bson.M{"updated_at": time.Now()}
	if remaining <= refundEpsilon {
		// Full refund: hide the bill and record why.
		set["is_deleted"] = true
		set["remark"] = fmt.Sprintf("[退款核销]全额退款%.2f元(退款单号%s);%s", refundAmount, refundOrderNo, bill.Remark)
	} else {
		// Partial refund: keep the bill at the reduced amount, to the cent.
		set["amount"] = math.Round(remaining*100) / 100
		set["remark"] = fmt.Sprintf("原金额%.2f元,退款%.2f元(退款单号%s);%s", bill.Amount, refundAmount, refundOrderNo, bill.Remark)
	}

	result, err := r.cleanedCollection.UpdateOne(ctx, bson.M{"_id": bill.ID}, bson.M{"$set": set})
	if err != nil {
		return false, fmt.Errorf("核销退款失败: %w", err)
	}
	// Report success only if the update actually reached a document.
	return result.MatchedCount > 0, nil
}
// 建议: 为提升 ReconcileRefund 查询性能,可为 bills_cleaned 添加索引
// {transaction_id:1, bill_type:1} 和 {merchant_order_no:1, bill_type:1}(与现有"无索引"问题一并处理)
// GetClient returns the underlying MongoDB client (kept for compatibility with
// older code). It is nil until Connect has succeeded.
func (r *Repository) GetClient() *mongo.Client {
	return r.client
}
// GetDB returns the database handle (kept for compatibility with older code).
// It is nil until Connect has succeeded.
func (r *Repository) GetDB() *mongo.Database {
	return r.db
}
// GetRawCollection returns the raw-bill collection handle (kept for
// compatibility with older code). It is nil until Connect has succeeded.
func (r *Repository) GetRawCollection() *mongo.Collection {
	return r.rawCollection
}
// GetCleanedCollection returns the cleaned-bill collection handle (kept for
// compatibility with older code). It is nil until Connect has succeeded.
func (r *Repository) GetCleanedCollection() *mongo.Collection {
	return r.cleanedCollection
}
// Compile-time check that *Repository satisfies repository.BillRepository.
var _ repository.BillRepository = (*Repository)(nil)