diff --git a/.gitignore b/.gitignore index 70e7625..6ddf0c0 100644 --- a/.gitignore +++ b/.gitignore @@ -20,5 +20,3 @@ server/billai-server server/uploads/ server/outputs/ *.log - -mongo/ \ No newline at end of file diff --git a/server/repository/mongo/repository.go b/server/repository/mongo/repository.go new file mode 100644 index 0000000..0ec82da --- /dev/null +++ b/server/repository/mongo/repository.go @@ -0,0 +1,327 @@ +// Package mongo 实现基于 MongoDB 的账单存储 +package mongo + +import ( + "context" + "fmt" + "time" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + + "billai-server/config" + "billai-server/model" + "billai-server/repository" +) + +// Repository MongoDB 账单存储实现 +type Repository struct { + client *mongo.Client + db *mongo.Database + rawCollection *mongo.Collection + cleanedCollection *mongo.Collection +} + +// NewRepository 创建 MongoDB 存储实例 +func NewRepository() *Repository { + return &Repository{} +} + +// Connect 连接 MongoDB +func (r *Repository) Connect() error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // 创建客户端选项 + clientOptions := options.Client().ApplyURI(config.Global.MongoURI) + + // 连接 MongoDB + client, err := mongo.Connect(ctx, clientOptions) + if err != nil { + return fmt.Errorf("连接 MongoDB 失败: %w", err) + } + + // 测试连接 + if err := client.Ping(ctx, nil); err != nil { + return fmt.Errorf("MongoDB Ping 失败: %w", err) + } + + // 设置实例变量 + r.client = client + r.db = client.Database(config.Global.MongoDatabase) + r.rawCollection = r.db.Collection(config.Global.MongoRawCollection) + r.cleanedCollection = r.db.Collection(config.Global.MongoCleanedCollection) + + fmt.Printf("🍃 MongoDB 连接成功: %s\n", config.Global.MongoDatabase) + fmt.Printf(" 📄 原始数据集合: %s\n", config.Global.MongoRawCollection) + fmt.Printf(" 📄 清洗数据集合: %s\n", config.Global.MongoCleanedCollection) + return nil +} + +// Disconnect 断开 MongoDB 连接 +func (r *Repository) Disconnect() error { + if r.client == nil { + return nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := r.client.Disconnect(ctx); err != nil { + return fmt.Errorf("断开 MongoDB 连接失败: %w", err) + } + + fmt.Println("🍃 MongoDB 连接已断开") + return nil +} + +// SaveRawBills 保存原始账单数据 +func (r *Repository) SaveRawBills(bills []model.RawBill) (int, error) { + if len(bills) == 0 { + return 0, nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // 转换为 interface{} 切片 + docs := make([]interface{}, len(bills)) + for i, bill := range bills { + docs[i] = bill + } + + result, err := r.rawCollection.InsertMany(ctx, docs) + if err != nil { + return 0, fmt.Errorf("插入原始账单失败: %w", err) + } + + return len(result.InsertedIDs), nil +} + +// SaveCleanedBills 保存清洗后的账单数据(带去重) +func (r *Repository) SaveCleanedBills(bills []model.CleanedBill) (saved int, duplicates int, err error) { + if len(bills) == 0 { + return 0, 0, nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + for _, bill := range bills { + // 检查是否重复 + isDup, err := r.CheckCleanedDuplicate(&bill) + if err != nil { + return saved, duplicates, err + } + if isDup { + duplicates++ + continue + } + + // 插入新记录 + _, err = r.cleanedCollection.InsertOne(ctx, bill) + if err != nil { + return saved, duplicates, fmt.Errorf("插入清洗后账单失败: %w", err) + } + saved++ + } + + return saved, duplicates, nil +} + +// CheckRawDuplicate 检查原始数据是否重复 +func (r *Repository) CheckRawDuplicate(fieldName, value string) (bool, error) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + filter := bson.M{"raw_data." + fieldName: value} + count, err := r.rawCollection.CountDocuments(ctx, filter) + if err != nil { + return false, err + } + + return count > 0, nil +} + +// CheckCleanedDuplicate 检查清洗后数据是否重复 +func (r *Repository) CheckCleanedDuplicate(bill *model.CleanedBill) (bool, error) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + var filter bson.M + + if bill.TransactionID != "" { + // 优先用交易订单号判断 + filter = bson.M{"transaction_id": bill.TransactionID} + } else { + // 回退到 时间+金额+商户 组合判断 + filter = bson.M{ + "time": bill.Time, + "amount": bill.Amount, + "merchant": bill.Merchant, + } + } + + count, err := r.cleanedCollection.CountDocuments(ctx, filter) + if err != nil { + return false, err + } + + return count > 0, nil +} + +// CountRawByField 按字段统计原始数据数量 +func (r *Repository) CountRawByField(fieldName, value string) (int64, error) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + filter := bson.M{"raw_data." + fieldName: value} + return r.rawCollection.CountDocuments(ctx, filter) +} + +// GetCleanedBills 获取清洗后的账单列表 +func (r *Repository) GetCleanedBills(filter map[string]interface{}) ([]model.CleanedBill, error) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // 转换 filter + bsonFilter := bson.M{} + for k, v := range filter { + bsonFilter[k] = v + } + + // 按时间倒序排列 + opts := options.Find().SetSort(bson.D{{Key: "time", Value: -1}}) + + cursor, err := r.cleanedCollection.Find(ctx, bsonFilter, opts) + if err != nil { + return nil, fmt.Errorf("查询账单失败: %w", err) + } + defer cursor.Close(ctx) + + var bills []model.CleanedBill + if err := cursor.All(ctx, &bills); err != nil { + return nil, fmt.Errorf("解析账单数据失败: %w", err) + } + + return bills, nil +} + +// GetCleanedBillsPaged 获取清洗后的账单列表(带分页) +func (r *Repository) GetCleanedBillsPaged(filter map[string]interface{}, page, pageSize int) ([]model.CleanedBill, int64, error) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // 转换 filter + bsonFilter := bson.M{} + for k, v := range filter { + bsonFilter[k] = v + } + + // 计算总数 + total, err := r.cleanedCollection.CountDocuments(ctx, bsonFilter) + if err != nil { + return nil, 0, fmt.Errorf("统计账单数量失败: %w", err) + } + + // 计算跳过数量 + skip := int64((page - 1) * pageSize) + + // 查询选项:分页 + 按时间倒序 + opts := options.Find(). + SetSort(bson.D{{Key: "time", Value: -1}}). + SetSkip(skip). + SetLimit(int64(pageSize)) + + cursor, err := r.cleanedCollection.Find(ctx, bsonFilter, opts) + if err != nil { + return nil, 0, fmt.Errorf("查询账单失败: %w", err) + } + defer cursor.Close(ctx) + + var bills []model.CleanedBill + if err := cursor.All(ctx, &bills); err != nil { + return nil, 0, fmt.Errorf("解析账单数据失败: %w", err) + } + + return bills, total, nil +} + +// GetBillsAggregate 获取账单聚合统计(总收入、总支出) +func (r *Repository) GetBillsAggregate(filter map[string]interface{}) (totalExpense float64, totalIncome float64, err error) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // 转换 filter + bsonFilter := bson.M{} + for k, v := range filter { + bsonFilter[k] = v + } + + // 使用聚合管道按 income_expense 分组统计金额 + pipeline := mongo.Pipeline{ + {{Key: "$match", Value: bsonFilter}}, + {{Key: "$group", Value: bson.D{ + {Key: "_id", Value: "$income_expense"}, + {Key: "total", Value: bson.D{{Key: "$sum", Value: "$amount"}}}, + }}}, + } + + cursor, err := r.cleanedCollection.Aggregate(ctx, pipeline) + if err != nil { + return 0, 0, fmt.Errorf("聚合统计失败: %w", err) + } + defer cursor.Close(ctx) + + // 解析结果 + var results []struct { + ID string `bson:"_id"` + Total float64 `bson:"total"` + } + if err := cursor.All(ctx, &results); err != nil { + return 0, 0, fmt.Errorf("解析聚合结果失败: %w", err) + } + + for _, result := range results { + switch result.ID { + case "支出": + totalExpense = result.Total + case "收入": + totalIncome = result.Total + } + } + + return totalExpense, totalIncome, nil +} + +// GetBillsNeedReview 获取需要复核的账单 +func (r *Repository) GetBillsNeedReview() ([]model.CleanedBill, error) { + filter := map[string]interface{}{ + "review_level": bson.M{"$in": []string{"HIGH", "LOW"}}, + } + return r.GetCleanedBills(filter) +} + +// GetClient 获取 MongoDB 客户端(用于兼容旧代码) +func (r *Repository) GetClient() *mongo.Client { + return r.client +} + +// GetDB 获取数据库实例(用于兼容旧代码) +func (r *Repository) GetDB() *mongo.Database { + return r.db +} + +// GetRawCollection 获取原始数据集合(用于兼容旧代码) +func (r *Repository) GetRawCollection() *mongo.Collection { + return r.rawCollection +} + +// GetCleanedCollection 获取清洗后数据集合(用于兼容旧代码) +func (r *Repository) GetCleanedCollection() *mongo.Collection { + return r.cleanedCollection +} + +// 确保 Repository 实现了 repository.BillRepository 接口 +var _ repository.BillRepository = (*Repository)(nil)