Files
billai/server/repository/mongo/repository.go

518 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Package mongo 实现基于 MongoDB 的账单存储
package mongo
import (
	"context"
	"errors"
	"fmt"
	"time"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"

	"billai-server/config"
	"billai-server/model"
	"billai-server/repository"
)
// Repository is the MongoDB-backed bill store. NewRepository returns it empty;
// Connect fills in the client, the database and the two collection handles.
type Repository struct {
	client *mongo.Client   // driver client; nil until Connect succeeds
	db     *mongo.Database // database named by config.Global.MongoDatabase

	rawCollection     *mongo.Collection // raw (uncleaned) bill documents
	cleanedCollection *mongo.Collection // cleaned bill documents
}
// NewRepository returns an unconnected Repository; call Connect before using it.
func NewRepository() *Repository {
	repo := new(Repository)
	return repo
}
// Connect dials MongoDB at config.Global.MongoURI, verifies the connection with
// a ping, and binds the database and the raw/cleaned collections named in
// config.Global. Connecting and pinging share a 10-second timeout.
//
// The Repository fields are assigned only after the ping succeeds. If the ping
// fails, the freshly created client is disconnected before the error is
// returned, so its resources are not leaked and the Repository is left as it was.
func (r *Repository) Connect() error {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	clientOptions := options.Client().ApplyURI(config.Global.MongoURI)

	client, err := mongo.Connect(ctx, clientOptions)
	if err != nil {
		return fmt.Errorf("连接 MongoDB 失败: %w", err)
	}

	if err := client.Ping(ctx, nil); err != nil {
		// ctx may already be expired (a timeout is a common ping failure), so
		// release the client with its own short deadline. The ping error is the
		// one worth reporting; a secondary disconnect failure is deliberately dropped.
		dctx, dcancel := context.WithTimeout(context.Background(), 5*time.Second)
		_ = client.Disconnect(dctx)
		dcancel()
		return fmt.Errorf("MongoDB Ping 失败: %w", err)
	}

	r.client = client
	r.db = client.Database(config.Global.MongoDatabase)
	r.rawCollection = r.db.Collection(config.Global.MongoRawCollection)
	r.cleanedCollection = r.db.Collection(config.Global.MongoCleanedCollection)

	fmt.Printf("🍃 MongoDB 连接成功: %s\n", config.Global.MongoDatabase)
	fmt.Printf(" 📄 原始数据集合: %s\n", config.Global.MongoRawCollection)
	fmt.Printf(" 📄 清洗数据集合: %s\n", config.Global.MongoCleanedCollection)
	return nil
}
// Disconnect closes the MongoDB client under a 10-second timeout. It is a no-op
// returning nil when no client was ever connected.
func (r *Repository) Disconnect() error {
	client := r.client
	if client == nil {
		return nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	err := client.Disconnect(ctx)
	if err != nil {
		return fmt.Errorf("断开 MongoDB 连接失败: %w", err)
	}
	fmt.Println("🍃 MongoDB 连接已断开")
	return nil
}
// SaveRawBills inserts all raw bills into the raw collection with one InsertMany
// call bounded by a 30-second timeout. It returns the number of inserted IDs the
// driver reports. An empty slice is a no-op returning 0.
func (r *Repository) SaveRawBills(bills []model.RawBill) (int, error) {
	if len(bills) == 0 {
		return 0, nil
	}

	// InsertMany takes a []interface{}, so each bill is boxed individually.
	documents := make([]interface{}, 0, len(bills))
	for i := range bills {
		documents = append(documents, bills[i])
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	res, err := r.rawCollection.InsertMany(ctx, documents)
	if err != nil {
		return 0, fmt.Errorf("插入原始账单失败: %w", err)
	}
	return len(res.InsertedIDs), nil
}
// SaveCleanedBills inserts cleaned bills one at a time and skips any bill that
// CheckCleanedDuplicate reports as already stored. Bills are inserted in order,
// so a later bill in the batch is also checked against earlier ones just inserted.
//
// It returns how many bills were inserted and how many were skipped as
// duplicates. On error, the counts cover the work done before the failure.
// All inserts share one 30-second timeout; each duplicate check applies its own
// timeout inside CheckCleanedDuplicate.
func (r *Repository) SaveCleanedBills(bills []model.CleanedBill) (saved int, duplicates int, err error) {
	if len(bills) == 0 {
		return 0, 0, nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	for i := range bills {
		bill := bills[i]

		dup, checkErr := r.CheckCleanedDuplicate(&bill)
		switch {
		case checkErr != nil:
			return saved, duplicates, checkErr
		case dup:
			duplicates++
			continue
		}

		if _, insertErr := r.cleanedCollection.InsertOne(ctx, bill); insertErr != nil {
			return saved, duplicates, fmt.Errorf("插入清洗后账单失败: %w", insertErr)
		}
		saved++
	}
	return saved, duplicates, nil
}
// CheckRawDuplicate reports whether any raw document has raw_data.<fieldName>
// equal to value. The query is bounded by a 5-second timeout.
//
// Only existence matters, so the count is capped at one match. The server can
// stop at the first hit instead of counting every matching document.
func (r *Repository) CheckRawDuplicate(fieldName, value string) (bool, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	filter := bson.M{"raw_data." + fieldName: value}
	count, err := r.rawCollection.CountDocuments(ctx, filter, options.Count().SetLimit(1))
	if err != nil {
		return false, err
	}
	return count > 0, nil
}
// CheckCleanedDuplicate reports whether bill already exists in the cleaned
// collection. When bill.TransactionID is non-empty it is the only match key.
// Otherwise the (time, amount, merchant) combination is used. The filter does
// not look at is_deleted, so soft-deleted documents also count as matches.
// The query is bounded by a 5-second timeout.
//
// Only existence matters, so the count is capped at one match. This check runs
// once per bill during SaveCleanedBills, so letting the server stop early matters.
func (r *Repository) CheckCleanedDuplicate(bill *model.CleanedBill) (bool, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	var filter bson.M
	if bill.TransactionID != "" {
		// Prefer the transaction/order number when one is available.
		filter = bson.M{"transaction_id": bill.TransactionID}
	} else {
		// Fall back to the time + amount + merchant combination.
		filter = bson.M{
			"time":     bill.Time.Time(), // convert to time.Time for the MongoDB query
			"amount":   bill.Amount,
			"merchant": bill.Merchant,
		}
	}

	count, err := r.cleanedCollection.CountDocuments(ctx, filter, options.Count().SetLimit(1))
	if err != nil {
		return false, err
	}
	return count > 0, nil
}
// CountRawByField returns the number of raw documents whose raw_data.<fieldName>
// equals value. The query is bounded by a 5-second timeout.
func (r *Repository) CountRawByField(fieldName, value string) (int64, error) {
	path := "raw_data." + fieldName

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	return r.rawCollection.CountDocuments(ctx, bson.M{path: value})
}
// GetCleanedBills returns every cleaned bill matching filter, newest first
// (sorted by "time" descending). Soft-deleted bills are always excluded: any
// "is_deleted" entry supplied by the caller is replaced with {$ne: true}.
// The caller's map is copied, not modified. The query is bounded by a
// 30-second timeout.
func (r *Repository) GetCleanedBills(filter map[string]interface{}) ([]model.CleanedBill, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	query := make(bson.M, len(filter)+1)
	for key, val := range filter {
		query[key] = val
	}
	query["is_deleted"] = bson.M{"$ne": true}

	findOpts := options.Find()
	findOpts.SetSort(bson.D{{Key: "time", Value: -1}})

	cur, err := r.cleanedCollection.Find(ctx, query, findOpts)
	if err != nil {
		return nil, fmt.Errorf("查询账单失败: %w", err)
	}
	defer cur.Close(ctx)

	var bills []model.CleanedBill
	if err = cur.All(ctx, &bills); err != nil {
		return nil, fmt.Errorf("解析账单数据失败: %w", err)
	}
	return bills, nil
}
// GetCleanedBillsPaged returns one page of cleaned bills matching filter, newest
// first, together with the total number of matching bills across all pages.
// Soft-deleted bills are always excluded; a caller-supplied "is_deleted" entry
// is overridden. The count and the find share a 30-second timeout.
//
// page is 1-based. Paging arguments are normalized before reaching MongoDB,
// which rejects a negative skip:
//   - page < 1 is treated as page 1;
//   - pageSize < 0 is rejected with an error;
//   - pageSize == 0 is passed through as limit 0, which the driver treats as
//     "no limit".
func (r *Repository) GetCleanedBillsPaged(filter map[string]interface{}, page, pageSize int) ([]model.CleanedBill, int64, error) {
	if pageSize < 0 {
		return nil, 0, fmt.Errorf("invalid page size: %d", pageSize)
	}
	if page < 1 {
		page = 1
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	bsonFilter := bson.M{}
	for k, v := range filter {
		bsonFilter[k] = v
	}
	bsonFilter["is_deleted"] = bson.M{"$ne": true}

	total, err := r.cleanedCollection.CountDocuments(ctx, bsonFilter)
	if err != nil {
		return nil, 0, fmt.Errorf("统计账单数量失败: %w", err)
	}

	// Multiply in int64 so a large page number cannot overflow a 32-bit int.
	skip := int64(page-1) * int64(pageSize)

	opts := options.Find().
		SetSort(bson.D{{Key: "time", Value: -1}}).
		SetSkip(skip).
		SetLimit(int64(pageSize))

	cursor, err := r.cleanedCollection.Find(ctx, bsonFilter, opts)
	if err != nil {
		return nil, 0, fmt.Errorf("查询账单失败: %w", err)
	}
	defer cursor.Close(ctx)

	var bills []model.CleanedBill
	if err := cursor.All(ctx, &bills); err != nil {
		return nil, 0, fmt.Errorf("解析账单数据失败: %w", err)
	}
	return bills, total, nil
}
// GetBillsAggregate sums "amount" over all non-deleted cleaned bills matching
// filter, grouped by "income_expense". The "支出" group's sum is returned as
// totalExpense and the "收入" group's sum as totalIncome. Any other group is
// ignored, and a missing group yields 0. The aggregation is bounded by a
// 30-second timeout.
func (r *Repository) GetBillsAggregate(filter map[string]interface{}) (totalExpense float64, totalIncome float64, err error) {
	match := bson.M{}
	for key, val := range filter {
		match[key] = val
	}
	match["is_deleted"] = bson.M{"$ne": true}

	groupStage := bson.D{
		{Key: "_id", Value: "$income_expense"},
		{Key: "total", Value: bson.D{{Key: "$sum", Value: "$amount"}}},
	}
	pipeline := mongo.Pipeline{
		{{Key: "$match", Value: match}},
		{{Key: "$group", Value: groupStage}},
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	cur, aggErr := r.cleanedCollection.Aggregate(ctx, pipeline)
	if aggErr != nil {
		return 0, 0, fmt.Errorf("聚合统计失败: %w", aggErr)
	}
	defer cur.Close(ctx)

	type groupTotal struct {
		ID    string  `bson:"_id"`
		Total float64 `bson:"total"`
	}
	var groups []groupTotal
	if decodeErr := cur.All(ctx, &groups); decodeErr != nil {
		return 0, 0, fmt.Errorf("解析聚合结果失败: %w", decodeErr)
	}

	for _, g := range groups {
		switch g.ID {
		case "支出":
			totalExpense = g.Total
		case "收入":
			totalIncome = g.Total
		}
	}
	return totalExpense, totalIncome, nil
}
// GetBillsNeedReview returns the non-deleted cleaned bills whose review_level is
// "HIGH" or "LOW". It delegates to GetCleanedBills.
func (r *Repository) GetBillsNeedReview() ([]model.CleanedBill, error) {
	levels := []string{"HIGH", "LOW"}
	return r.GetCleanedBills(map[string]interface{}{
		"review_level": bson.M{"$in": levels},
		"is_deleted":   bson.M{"$ne": true},
	})
}
// GetMonthlyStats returns per-month expense and income totals over all
// non-deleted cleaned bills. It takes no filter, so list filters do not affect it.
// Months are "YYYY-MM" strings produced by $dateToString on the "time" field and
// are returned in ascending order. The aggregation is bounded by a 30-second
// timeout.
//
// NOTE(review): $dateToString is called without a timezone, so months are
// computed in UTC. Bills recorded near a month boundary in a non-UTC local time
// may fall into the adjacent month; confirm this is intended.
func (r *Repository) GetMonthlyStats() ([]model.MonthlyStat, error) {
	// sumWhen builds {$sum: {$cond: [{$eq: ["$_id.income_expense", kind]}, "$total", 0]}},
	// i.e. the total for one income/expense kind within a month group.
	sumWhen := func(kind string) bson.D {
		return bson.D{{Key: "$sum", Value: bson.D{{Key: "$cond", Value: bson.A{
			bson.D{{Key: "$eq", Value: bson.A{"$_id.income_expense", kind}}},
			"$total",
			0,
		}}}}}
	}

	notDeleted := bson.D{{Key: "$match", Value: bson.M{"is_deleted": bson.M{"$ne": true}}}}
	addMonth := bson.D{{Key: "$addFields", Value: bson.D{
		{Key: "month", Value: bson.D{{Key: "$dateToString", Value: bson.D{
			{Key: "format", Value: "%Y-%m"},
			{Key: "date", Value: "$time"},
		}}}},
	}}}
	// First grouping: one total per (month, income_expense) pair.
	groupByMonthAndKind := bson.D{{Key: "$group", Value: bson.D{
		{Key: "_id", Value: bson.D{
			{Key: "month", Value: "$month"},
			{Key: "income_expense", Value: "$income_expense"},
		}},
		{Key: "total", Value: bson.D{{Key: "$sum", Value: "$amount"}}},
	}}}
	// Second grouping: fold the per-kind totals into expense/income columns per month.
	groupByMonth := bson.D{{Key: "$group", Value: bson.D{
		{Key: "_id", Value: "$_id.month"},
		{Key: "expense", Value: sumWhen("支出")},
		{Key: "income", Value: sumWhen("收入")},
	}}}
	sortByMonth := bson.D{{Key: "$sort", Value: bson.D{{Key: "_id", Value: 1}}}}

	pipeline := mongo.Pipeline{notDeleted, addMonth, groupByMonthAndKind, groupByMonth, sortByMonth}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	cur, err := r.cleanedCollection.Aggregate(ctx, pipeline)
	if err != nil {
		return nil, fmt.Errorf("月度统计聚合失败: %w", err)
	}
	defer cur.Close(ctx)

	var rows []struct {
		Month   string  `bson:"_id"`
		Expense float64 `bson:"expense"`
		Income  float64 `bson:"income"`
	}
	if err := cur.All(ctx, &rows); err != nil {
		return nil, fmt.Errorf("解析月度统计结果失败: %w", err)
	}

	stats := make([]model.MonthlyStat, 0, len(rows))
	for _, row := range rows {
		stats = append(stats, model.MonthlyStat{
			Month:   row.Month,
			Expense: row.Expense,
			Income:  row.Income,
		})
	}
	return stats, nil
}
// UpdateCleanedBillByID applies updates as a $set to the cleaned bill whose _id
// is the hex ObjectID id, and returns the document as it is after the update.
// The operation is bounded by a 10-second timeout.
//
// It returns an error when:
//   - the cleaned collection is not initialized (Connect not yet succeeded);
//   - id is not a valid hex ObjectID;
//   - updates is empty;
//   - no document has that _id (repository.ErrNotFound);
//   - the driver call fails (the driver error is wrapped).
//
// NOTE(review): the filter does not exclude soft-deleted bills, so a bill hidden
// from GetCleanedBills can still be updated here. Confirm this is intended.
func (r *Repository) UpdateCleanedBillByID(id string, updates map[string]interface{}) (*model.CleanedBill, error) {
	if r.cleanedCollection == nil {
		return nil, fmt.Errorf("cleaned collection not initialized")
	}
	oid, err := primitive.ObjectIDFromHex(id)
	if err != nil {
		return nil, fmt.Errorf("invalid id: %w", err)
	}
	if len(updates) == 0 {
		return nil, fmt.Errorf("no updates")
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	filter := bson.M{"_id": oid}
	update := bson.M{"$set": updates}
	opts := options.FindOneAndUpdate().SetReturnDocument(options.After)

	var updated model.CleanedBill
	err = r.cleanedCollection.FindOneAndUpdate(ctx, filter, update, opts).Decode(&updated)
	if err != nil {
		// errors.Is also recognizes the sentinel if a driver layer wraps it.
		if errors.Is(err, mongo.ErrNoDocuments) {
			return nil, repository.ErrNotFound
		}
		return nil, fmt.Errorf("update bill failed: %w", err)
	}
	return &updated, nil
}
// DeleteCleanedBillByID soft-deletes the cleaned bill whose _id is the hex
// ObjectID id by setting is_deleted to true; the document itself is kept.
// It returns repository.ErrNotFound when no document has that _id. A bill that
// is already soft-deleted still matches, so the call returns nil. The update is
// bounded by a 10-second timeout.
func (r *Repository) DeleteCleanedBillByID(id string) error {
	coll := r.cleanedCollection
	if coll == nil {
		return fmt.Errorf("cleaned collection not initialized")
	}
	oid, parseErr := primitive.ObjectIDFromHex(id)
	if parseErr != nil {
		return fmt.Errorf("invalid id: %w", parseErr)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	res, err := coll.UpdateOne(ctx,
		bson.M{"_id": oid},
		bson.M{"$set": bson.M{"is_deleted": true}},
	)
	switch {
	case err != nil:
		return fmt.Errorf("soft delete bill failed: %w", err)
	case res.MatchedCount == 0:
		return repository.ErrNotFound
	default:
		return nil
	}
}
// SoftDeleteJDRelatedBills soft-deletes every not-yet-deleted cleaned bill whose
// bill_type is not "jd" and whose description contains "京东-订单编号" (a JD order
// number). It sets is_deleted and updated_at on each one. The purpose is to keep
// JD purchases that also appear in other sources (WeChat, Alipay) from being
// counted twice.
// It returns the number of documents modified. The update is bounded by a
// 30-second timeout.
func (r *Repository) SoftDeleteJDRelatedBills() (int64, error) {
	if r.cleanedCollection == nil {
		return 0, fmt.Errorf("cleaned collection not initialized")
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	notJD := bson.M{"$ne": "jd"}
	mentionsJDOrder := bson.M{"$regex": "京东-订单编号", "$options": ""}
	notYetDeleted := bson.M{"$ne": true}
	filter := bson.M{
		"bill_type":   notJD,
		"description": mentionsJDOrder,
		"is_deleted":  notYetDeleted,
	}

	markDeleted := bson.M{"$set": bson.M{
		"is_deleted": true,
		"updated_at": time.Now(),
	}}

	res, err := r.cleanedCollection.UpdateMany(ctx, filter, markDeleted)
	if err != nil {
		return 0, fmt.Errorf("soft delete JD related bills failed: %w", err)
	}
	return res.ModifiedCount, nil
}
// GetClient exposes the underlying MongoDB client for legacy callers.
// It is nil until Connect succeeds.
func (r *Repository) GetClient() *mongo.Client { return r.client }

// GetDB exposes the bound database handle for legacy callers.
// It is nil until Connect succeeds.
func (r *Repository) GetDB() *mongo.Database { return r.db }

// GetRawCollection exposes the raw-bill collection handle for legacy callers.
// It is nil until Connect succeeds.
func (r *Repository) GetRawCollection() *mongo.Collection { return r.rawCollection }

// GetCleanedCollection exposes the cleaned-bill collection handle for legacy
// callers. It is nil until Connect succeeds.
func (r *Repository) GetCleanedCollection() *mongo.Collection { return r.cleanedCollection }
// Compile-time check that *Repository implements repository.BillRepository;
// the build fails here if any interface method is missing or has the wrong signature.
var _ repository.BillRepository = (*Repository)(nil)