// Package mongo 实现基于 MongoDB 的账单存储 package mongo import ( "context" "fmt" "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "billai-server/config" "billai-server/model" "billai-server/repository" ) // Repository MongoDB 账单存储实现 type Repository struct { client *mongo.Client db *mongo.Database rawCollection *mongo.Collection cleanedCollection *mongo.Collection } // NewRepository 创建 MongoDB 存储实例 func NewRepository() *Repository { return &Repository{} } // Connect 连接 MongoDB func (r *Repository) Connect() error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() // 创建客户端选项 clientOptions := options.Client().ApplyURI(config.Global.MongoURI) // 连接 MongoDB client, err := mongo.Connect(ctx, clientOptions) if err != nil { return fmt.Errorf("连接 MongoDB 失败: %w", err) } // 测试连接 if err := client.Ping(ctx, nil); err != nil { return fmt.Errorf("MongoDB Ping 失败: %w", err) } // 设置实例变量 r.client = client r.db = client.Database(config.Global.MongoDatabase) r.rawCollection = r.db.Collection(config.Global.MongoRawCollection) r.cleanedCollection = r.db.Collection(config.Global.MongoCleanedCollection) fmt.Printf("🍃 MongoDB 连接成功: %s\n", config.Global.MongoDatabase) fmt.Printf(" 📄 原始数据集合: %s\n", config.Global.MongoRawCollection) fmt.Printf(" 📄 清洗数据集合: %s\n", config.Global.MongoCleanedCollection) return nil } // Disconnect 断开 MongoDB 连接 func (r *Repository) Disconnect() error { if r.client == nil { return nil } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := r.client.Disconnect(ctx); err != nil { return fmt.Errorf("断开 MongoDB 连接失败: %w", err) } fmt.Println("🍃 MongoDB 连接已断开") return nil } // SaveRawBills 保存原始账单数据 func (r *Repository) SaveRawBills(bills []model.RawBill) (int, error) { if len(bills) == 0 { return 0, nil } ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() // 转换为 interface{} 切片 docs := make([]interface{}, len(bills)) for i, bill := range bills { docs[i] = bill } result, err := r.rawCollection.InsertMany(ctx, docs) if err != nil { return 0, fmt.Errorf("插入原始账单失败: %w", err) } return len(result.InsertedIDs), nil } // SaveCleanedBills 保存清洗后的账单数据(带去重) func (r *Repository) SaveCleanedBills(bills []model.CleanedBill) (saved int, duplicates int, err error) { if len(bills) == 0 { return 0, 0, nil } ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() for _, bill := range bills { // 检查是否重复 isDup, err := r.CheckCleanedDuplicate(&bill) if err != nil { return saved, duplicates, err } if isDup { duplicates++ continue } // 插入新记录 _, err = r.cleanedCollection.InsertOne(ctx, bill) if err != nil { return saved, duplicates, fmt.Errorf("插入清洗后账单失败: %w", err) } saved++ } return saved, duplicates, nil } // CheckRawDuplicate 检查原始数据是否重复 func (r *Repository) CheckRawDuplicate(fieldName, value string) (bool, error) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() filter := bson.M{"raw_data." + fieldName: value} count, err := r.rawCollection.CountDocuments(ctx, filter) if err != nil { return false, err } return count > 0, nil } // CheckCleanedDuplicate 检查清洗后数据是否重复 func (r *Repository) CheckCleanedDuplicate(bill *model.CleanedBill) (bool, error) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() var filter bson.M if bill.TransactionID != "" { // 优先用交易订单号判断 filter = bson.M{"transaction_id": bill.TransactionID} } else { // 回退到 时间+金额+商户 组合判断 filter = bson.M{ "time": bill.Time.Time(), // 转换为 time.Time 用于 MongoDB 查询 "amount": bill.Amount, "merchant": bill.Merchant, } } count, err := r.cleanedCollection.CountDocuments(ctx, filter) if err != nil { return false, err } return count > 0, nil } // CountRawByField 按字段统计原始数据数量 func (r *Repository) CountRawByField(fieldName, value string) (int64, error) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() filter := bson.M{"raw_data." + fieldName: value} return r.rawCollection.CountDocuments(ctx, filter) } // GetCleanedBills 获取清洗后的账单列表 func (r *Repository) GetCleanedBills(filter map[string]interface{}) ([]model.CleanedBill, error) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() // 转换 filter bsonFilter := bson.M{} for k, v := range filter { bsonFilter[k] = v } // 按时间倒序排列 opts := options.Find().SetSort(bson.D{{Key: "time", Value: -1}}) cursor, err := r.cleanedCollection.Find(ctx, bsonFilter, opts) if err != nil { return nil, fmt.Errorf("查询账单失败: %w", err) } defer cursor.Close(ctx) var bills []model.CleanedBill if err := cursor.All(ctx, &bills); err != nil { return nil, fmt.Errorf("解析账单数据失败: %w", err) } return bills, nil } // GetCleanedBillsPaged 获取清洗后的账单列表(带分页) func (r *Repository) GetCleanedBillsPaged(filter map[string]interface{}, page, pageSize int) ([]model.CleanedBill, int64, error) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() // 转换 filter bsonFilter := bson.M{} for k, v := range filter { bsonFilter[k] = v } // 计算总数 total, err := r.cleanedCollection.CountDocuments(ctx, bsonFilter) if err != nil { return nil, 0, fmt.Errorf("统计账单数量失败: %w", err) } // 计算跳过数量 skip := int64((page - 1) * pageSize) // 查询选项:分页 + 按时间倒序 opts := options.Find(). SetSort(bson.D{{Key: "time", Value: -1}}). SetSkip(skip). SetLimit(int64(pageSize)) cursor, err := r.cleanedCollection.Find(ctx, bsonFilter, opts) if err != nil { return nil, 0, fmt.Errorf("查询账单失败: %w", err) } defer cursor.Close(ctx) var bills []model.CleanedBill if err := cursor.All(ctx, &bills); err != nil { return nil, 0, fmt.Errorf("解析账单数据失败: %w", err) } return bills, total, nil } // GetBillsAggregate 获取账单聚合统计(总收入、总支出) func (r *Repository) GetBillsAggregate(filter map[string]interface{}) (totalExpense float64, totalIncome float64, err error) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() // 转换 filter bsonFilter := bson.M{} for k, v := range filter { bsonFilter[k] = v } // 使用聚合管道按 income_expense 分组统计金额 pipeline := mongo.Pipeline{ {{Key: "$match", Value: bsonFilter}}, {{Key: "$group", Value: bson.D{ {Key: "_id", Value: "$income_expense"}, {Key: "total", Value: bson.D{{Key: "$sum", Value: "$amount"}}}, }}}, } cursor, err := r.cleanedCollection.Aggregate(ctx, pipeline) if err != nil { return 0, 0, fmt.Errorf("聚合统计失败: %w", err) } defer cursor.Close(ctx) // 解析结果 var results []struct { ID string `bson:"_id"` Total float64 `bson:"total"` } if err := cursor.All(ctx, &results); err != nil { return 0, 0, fmt.Errorf("解析聚合结果失败: %w", err) } for _, result := range results { switch result.ID { case "支出": totalExpense = result.Total case "收入": totalIncome = result.Total } } return totalExpense, totalIncome, nil } // GetBillsNeedReview 获取需要复核的账单 func (r *Repository) GetBillsNeedReview() ([]model.CleanedBill, error) { filter := map[string]interface{}{ "review_level": bson.M{"$in": []string{"HIGH", "LOW"}}, } return r.GetCleanedBills(filter) } // GetMonthlyStats 获取月度统计(全部数据,不受筛选条件影响) func (r *Repository) GetMonthlyStats() ([]model.MonthlyStat, error) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() // 使用聚合管道按月份分组统计 // 先按月份和收支类型分组,再汇总 pipeline := mongo.Pipeline{ // 添加月份字段 {{Key: "$addFields", Value: bson.D{ {Key: "month", Value: bson.D{ {Key: "$dateToString", Value: bson.D{ {Key: "format", Value: "%Y-%m"}, {Key: "date", Value: "$time"}, }}, }}, }}}, // 按月份和收支类型分组 {{Key: "$group", Value: bson.D{ {Key: "_id", Value: bson.D{ {Key: "month", Value: "$month"}, {Key: "income_expense", Value: "$income_expense"}, }}, {Key: "total", Value: bson.D{{Key: "$sum", Value: "$amount"}}}, }}}, // 按月份重新分组,汇总收入和支出 {{Key: "$group", Value: bson.D{ {Key: "_id", Value: "$_id.month"}, {Key: "expense", Value: bson.D{ {Key: "$sum", Value: bson.D{ {Key: "$cond", Value: bson.A{ bson.D{{Key: "$eq", Value: bson.A{"$_id.income_expense", "支出"}}}, "$total", 0, }}, }}, }}, {Key: "income", Value: bson.D{ {Key: "$sum", Value: bson.D{ {Key: "$cond", Value: bson.A{ bson.D{{Key: "$eq", Value: bson.A{"$_id.income_expense", "收入"}}}, "$total", 0, }}, }}, }}, }}}, // 按月份排序 {{Key: "$sort", Value: bson.D{{Key: "_id", Value: 1}}}}, } cursor, err := r.cleanedCollection.Aggregate(ctx, pipeline) if err != nil { return nil, fmt.Errorf("月度统计聚合失败: %w", err) } defer cursor.Close(ctx) // 解析结果 var results []struct { Month string `bson:"_id"` Expense float64 `bson:"expense"` Income float64 `bson:"income"` } if err := cursor.All(ctx, &results); err != nil { return nil, fmt.Errorf("解析月度统计结果失败: %w", err) } // 转换为 MonthlyStat stats := make([]model.MonthlyStat, len(results)) for i, r := range results { stats[i] = model.MonthlyStat{ Month: r.Month, Expense: r.Expense, Income: r.Income, } } return stats, nil } // GetClient 获取 MongoDB 客户端(用于兼容旧代码) func (r *Repository) GetClient() *mongo.Client { return r.client } // GetDB 获取数据库实例(用于兼容旧代码) func (r *Repository) GetDB() *mongo.Database { return r.db } // GetRawCollection 获取原始数据集合(用于兼容旧代码) func (r *Repository) GetRawCollection() *mongo.Collection { return r.rawCollection } // GetCleanedCollection 获取清洗后数据集合(用于兼容旧代码) func (r *Repository) GetCleanedCollection() *mongo.Collection { return r.cleanedCollection } // 确保 Repository 实现了 repository.BillRepository 接口 var _ repository.BillRepository = (*Repository)(nil)