518 lines
14 KiB
Go
518 lines
14 KiB
Go
// Package mongo 实现基于 MongoDB 的账单存储
|
||
package mongo
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"time"
|
||
|
||
"go.mongodb.org/mongo-driver/bson"
|
||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||
"go.mongodb.org/mongo-driver/mongo"
|
||
"go.mongodb.org/mongo-driver/mongo/options"
|
||
|
||
"billai-server/config"
|
||
"billai-server/model"
|
||
"billai-server/repository"
|
||
)
|
||
|
||
// Repository MongoDB 账单存储实现
|
||
type Repository struct {
|
||
client *mongo.Client
|
||
db *mongo.Database
|
||
rawCollection *mongo.Collection
|
||
cleanedCollection *mongo.Collection
|
||
}
|
||
|
||
// NewRepository 创建 MongoDB 存储实例
|
||
func NewRepository() *Repository {
|
||
return &Repository{}
|
||
}
|
||
|
||
// Connect 连接 MongoDB
|
||
func (r *Repository) Connect() error {
|
||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||
defer cancel()
|
||
|
||
// 创建客户端选项
|
||
clientOptions := options.Client().ApplyURI(config.Global.MongoURI)
|
||
|
||
// 连接 MongoDB
|
||
client, err := mongo.Connect(ctx, clientOptions)
|
||
if err != nil {
|
||
return fmt.Errorf("连接 MongoDB 失败: %w", err)
|
||
}
|
||
|
||
// 测试连接
|
||
if err := client.Ping(ctx, nil); err != nil {
|
||
return fmt.Errorf("MongoDB Ping 失败: %w", err)
|
||
}
|
||
|
||
// 设置实例变量
|
||
r.client = client
|
||
r.db = client.Database(config.Global.MongoDatabase)
|
||
r.rawCollection = r.db.Collection(config.Global.MongoRawCollection)
|
||
r.cleanedCollection = r.db.Collection(config.Global.MongoCleanedCollection)
|
||
|
||
fmt.Printf("🍃 MongoDB 连接成功: %s\n", config.Global.MongoDatabase)
|
||
fmt.Printf(" 📄 原始数据集合: %s\n", config.Global.MongoRawCollection)
|
||
fmt.Printf(" 📄 清洗数据集合: %s\n", config.Global.MongoCleanedCollection)
|
||
return nil
|
||
}
|
||
|
||
// Disconnect 断开 MongoDB 连接
|
||
func (r *Repository) Disconnect() error {
|
||
if r.client == nil {
|
||
return nil
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||
defer cancel()
|
||
|
||
if err := r.client.Disconnect(ctx); err != nil {
|
||
return fmt.Errorf("断开 MongoDB 连接失败: %w", err)
|
||
}
|
||
|
||
fmt.Println("🍃 MongoDB 连接已断开")
|
||
return nil
|
||
}
|
||
|
||
// SaveRawBills 保存原始账单数据
|
||
func (r *Repository) SaveRawBills(bills []model.RawBill) (int, error) {
|
||
if len(bills) == 0 {
|
||
return 0, nil
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||
defer cancel()
|
||
|
||
// 转换为 interface{} 切片
|
||
docs := make([]interface{}, len(bills))
|
||
for i, bill := range bills {
|
||
docs[i] = bill
|
||
}
|
||
|
||
result, err := r.rawCollection.InsertMany(ctx, docs)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("插入原始账单失败: %w", err)
|
||
}
|
||
|
||
return len(result.InsertedIDs), nil
|
||
}
|
||
|
||
// SaveCleanedBills 保存清洗后的账单数据(带去重)
|
||
func (r *Repository) SaveCleanedBills(bills []model.CleanedBill) (saved int, duplicates int, err error) {
|
||
if len(bills) == 0 {
|
||
return 0, 0, nil
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||
defer cancel()
|
||
|
||
for _, bill := range bills {
|
||
// 检查是否重复
|
||
isDup, err := r.CheckCleanedDuplicate(&bill)
|
||
if err != nil {
|
||
return saved, duplicates, err
|
||
}
|
||
if isDup {
|
||
duplicates++
|
||
continue
|
||
}
|
||
|
||
// 插入新记录
|
||
_, err = r.cleanedCollection.InsertOne(ctx, bill)
|
||
if err != nil {
|
||
return saved, duplicates, fmt.Errorf("插入清洗后账单失败: %w", err)
|
||
}
|
||
saved++
|
||
}
|
||
|
||
return saved, duplicates, nil
|
||
}
|
||
|
||
// CheckRawDuplicate 检查原始数据是否重复
|
||
func (r *Repository) CheckRawDuplicate(fieldName, value string) (bool, error) {
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
|
||
filter := bson.M{"raw_data." + fieldName: value}
|
||
count, err := r.rawCollection.CountDocuments(ctx, filter)
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
|
||
return count > 0, nil
|
||
}
|
||
|
||
// CheckCleanedDuplicate 检查清洗后数据是否重复
|
||
func (r *Repository) CheckCleanedDuplicate(bill *model.CleanedBill) (bool, error) {
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
|
||
var filter bson.M
|
||
|
||
if bill.TransactionID != "" {
|
||
// 优先用交易订单号判断
|
||
filter = bson.M{"transaction_id": bill.TransactionID}
|
||
} else {
|
||
// 回退到 时间+金额+商户 组合判断
|
||
filter = bson.M{
|
||
"time": bill.Time.Time(), // 转换为 time.Time 用于 MongoDB 查询
|
||
"amount": bill.Amount,
|
||
"merchant": bill.Merchant,
|
||
}
|
||
}
|
||
|
||
count, err := r.cleanedCollection.CountDocuments(ctx, filter)
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
|
||
return count > 0, nil
|
||
}
|
||
|
||
// CountRawByField 按字段统计原始数据数量
|
||
func (r *Repository) CountRawByField(fieldName, value string) (int64, error) {
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
|
||
filter := bson.M{"raw_data." + fieldName: value}
|
||
return r.rawCollection.CountDocuments(ctx, filter)
|
||
}
|
||
|
||
// GetCleanedBills 获取清洗后的账单列表
|
||
func (r *Repository) GetCleanedBills(filter map[string]interface{}) ([]model.CleanedBill, error) {
|
||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||
defer cancel()
|
||
|
||
// 转换 filter
|
||
bsonFilter := bson.M{}
|
||
for k, v := range filter {
|
||
bsonFilter[k] = v
|
||
}
|
||
|
||
// 排除已删除的记录
|
||
bsonFilter["is_deleted"] = bson.M{"$ne": true}
|
||
|
||
// 按时间倒序排列
|
||
opts := options.Find().SetSort(bson.D{{Key: "time", Value: -1}})
|
||
|
||
cursor, err := r.cleanedCollection.Find(ctx, bsonFilter, opts)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("查询账单失败: %w", err)
|
||
}
|
||
defer cursor.Close(ctx)
|
||
|
||
var bills []model.CleanedBill
|
||
if err := cursor.All(ctx, &bills); err != nil {
|
||
return nil, fmt.Errorf("解析账单数据失败: %w", err)
|
||
}
|
||
|
||
return bills, nil
|
||
}
|
||
|
||
// GetCleanedBillsPaged 获取清洗后的账单列表(带分页)
|
||
func (r *Repository) GetCleanedBillsPaged(filter map[string]interface{}, page, pageSize int) ([]model.CleanedBill, int64, error) {
|
||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||
defer cancel()
|
||
|
||
// 转换 filter
|
||
bsonFilter := bson.M{}
|
||
for k, v := range filter {
|
||
bsonFilter[k] = v
|
||
}
|
||
|
||
// 排除已删除的记录
|
||
bsonFilter["is_deleted"] = bson.M{"$ne": true}
|
||
|
||
// 计算总数
|
||
total, err := r.cleanedCollection.CountDocuments(ctx, bsonFilter)
|
||
if err != nil {
|
||
return nil, 0, fmt.Errorf("统计账单数量失败: %w", err)
|
||
}
|
||
|
||
// 计算跳过数量
|
||
skip := int64((page - 1) * pageSize)
|
||
|
||
// 查询选项:分页 + 按时间倒序
|
||
opts := options.Find().
|
||
SetSort(bson.D{{Key: "time", Value: -1}}).
|
||
SetSkip(skip).
|
||
SetLimit(int64(pageSize))
|
||
|
||
cursor, err := r.cleanedCollection.Find(ctx, bsonFilter, opts)
|
||
if err != nil {
|
||
return nil, 0, fmt.Errorf("查询账单失败: %w", err)
|
||
}
|
||
defer cursor.Close(ctx)
|
||
|
||
var bills []model.CleanedBill
|
||
if err := cursor.All(ctx, &bills); err != nil {
|
||
return nil, 0, fmt.Errorf("解析账单数据失败: %w", err)
|
||
}
|
||
|
||
return bills, total, nil
|
||
}
|
||
|
||
// GetBillsAggregate 获取账单聚合统计(总收入、总支出)
|
||
func (r *Repository) GetBillsAggregate(filter map[string]interface{}) (totalExpense float64, totalIncome float64, err error) {
|
||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||
defer cancel()
|
||
|
||
// 转换 filter
|
||
bsonFilter := bson.M{}
|
||
for k, v := range filter {
|
||
bsonFilter[k] = v
|
||
}
|
||
|
||
// 排除已删除的记录
|
||
bsonFilter["is_deleted"] = bson.M{"$ne": true}
|
||
|
||
// 使用聚合管道按 income_expense 分组统计金额
|
||
pipeline := mongo.Pipeline{
|
||
{{Key: "$match", Value: bsonFilter}},
|
||
{{Key: "$group", Value: bson.D{
|
||
{Key: "_id", Value: "$income_expense"},
|
||
{Key: "total", Value: bson.D{{Key: "$sum", Value: "$amount"}}},
|
||
}}},
|
||
}
|
||
|
||
cursor, err := r.cleanedCollection.Aggregate(ctx, pipeline)
|
||
if err != nil {
|
||
return 0, 0, fmt.Errorf("聚合统计失败: %w", err)
|
||
}
|
||
defer cursor.Close(ctx)
|
||
|
||
// 解析结果
|
||
var results []struct {
|
||
ID string `bson:"_id"`
|
||
Total float64 `bson:"total"`
|
||
}
|
||
if err := cursor.All(ctx, &results); err != nil {
|
||
return 0, 0, fmt.Errorf("解析聚合结果失败: %w", err)
|
||
}
|
||
|
||
for _, result := range results {
|
||
switch result.ID {
|
||
case "支出":
|
||
totalExpense = result.Total
|
||
case "收入":
|
||
totalIncome = result.Total
|
||
}
|
||
}
|
||
|
||
return totalExpense, totalIncome, nil
|
||
}
|
||
|
||
// GetBillsNeedReview 获取需要复核的账单
|
||
func (r *Repository) GetBillsNeedReview() ([]model.CleanedBill, error) {
|
||
filter := map[string]interface{}{
|
||
"review_level": bson.M{"$in": []string{"HIGH", "LOW"}},
|
||
"is_deleted": bson.M{"$ne": true},
|
||
}
|
||
return r.GetCleanedBills(filter)
|
||
}
|
||
|
||
// GetMonthlyStats 获取月度统计(全部数据,不受筛选条件影响)
|
||
func (r *Repository) GetMonthlyStats() ([]model.MonthlyStat, error) {
|
||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||
defer cancel()
|
||
|
||
// 使用聚合管道按月份分组统计
|
||
// 先按月份和收支类型分组,再汇总
|
||
pipeline := mongo.Pipeline{
|
||
// 排除已删除的记录
|
||
{{Key: "$match", Value: bson.M{"is_deleted": bson.M{"$ne": true}}}},
|
||
// 添加月份字段
|
||
{{Key: "$addFields", Value: bson.D{
|
||
{Key: "month", Value: bson.D{
|
||
{Key: "$dateToString", Value: bson.D{
|
||
{Key: "format", Value: "%Y-%m"},
|
||
{Key: "date", Value: "$time"},
|
||
}},
|
||
}},
|
||
}}},
|
||
// 按月份和收支类型分组
|
||
{{Key: "$group", Value: bson.D{
|
||
{Key: "_id", Value: bson.D{
|
||
{Key: "month", Value: "$month"},
|
||
{Key: "income_expense", Value: "$income_expense"},
|
||
}},
|
||
{Key: "total", Value: bson.D{{Key: "$sum", Value: "$amount"}}},
|
||
}}},
|
||
// 按月份重新分组,汇总收入和支出
|
||
{{Key: "$group", Value: bson.D{
|
||
{Key: "_id", Value: "$_id.month"},
|
||
{Key: "expense", Value: bson.D{
|
||
{Key: "$sum", Value: bson.D{
|
||
{Key: "$cond", Value: bson.A{
|
||
bson.D{{Key: "$eq", Value: bson.A{"$_id.income_expense", "支出"}}},
|
||
"$total",
|
||
0,
|
||
}},
|
||
}},
|
||
}},
|
||
{Key: "income", Value: bson.D{
|
||
{Key: "$sum", Value: bson.D{
|
||
{Key: "$cond", Value: bson.A{
|
||
bson.D{{Key: "$eq", Value: bson.A{"$_id.income_expense", "收入"}}},
|
||
"$total",
|
||
0,
|
||
}},
|
||
}},
|
||
}},
|
||
}}},
|
||
// 按月份排序
|
||
{{Key: "$sort", Value: bson.D{{Key: "_id", Value: 1}}}},
|
||
}
|
||
|
||
cursor, err := r.cleanedCollection.Aggregate(ctx, pipeline)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("月度统计聚合失败: %w", err)
|
||
}
|
||
defer cursor.Close(ctx)
|
||
|
||
// 解析结果
|
||
var results []struct {
|
||
Month string `bson:"_id"`
|
||
Expense float64 `bson:"expense"`
|
||
Income float64 `bson:"income"`
|
||
}
|
||
if err := cursor.All(ctx, &results); err != nil {
|
||
return nil, fmt.Errorf("解析月度统计结果失败: %w", err)
|
||
}
|
||
|
||
// 转换为 MonthlyStat
|
||
stats := make([]model.MonthlyStat, len(results))
|
||
for i, r := range results {
|
||
stats[i] = model.MonthlyStat{
|
||
Month: r.Month,
|
||
Expense: r.Expense,
|
||
Income: r.Income,
|
||
}
|
||
}
|
||
|
||
return stats, nil
|
||
}
|
||
|
||
// UpdateCleanedBillByID 按 ID 更新清洗后的账单,并返回更新后的记录
|
||
func (r *Repository) UpdateCleanedBillByID(id string, updates map[string]interface{}) (*model.CleanedBill, error) {
|
||
if r.cleanedCollection == nil {
|
||
return nil, fmt.Errorf("cleaned collection not initialized")
|
||
}
|
||
|
||
oid, err := primitive.ObjectIDFromHex(id)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("invalid id: %w", err)
|
||
}
|
||
|
||
if len(updates) == 0 {
|
||
return nil, fmt.Errorf("no updates")
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||
defer cancel()
|
||
|
||
filter := bson.M{"_id": oid}
|
||
update := bson.M{"$set": updates}
|
||
opts := options.FindOneAndUpdate().SetReturnDocument(options.After)
|
||
|
||
var updated model.CleanedBill
|
||
err = r.cleanedCollection.FindOneAndUpdate(ctx, filter, update, opts).Decode(&updated)
|
||
if err != nil {
|
||
if err == mongo.ErrNoDocuments {
|
||
return nil, repository.ErrNotFound
|
||
}
|
||
return nil, fmt.Errorf("update bill failed: %w", err)
|
||
}
|
||
|
||
return &updated, nil
|
||
}
|
||
|
||
// DeleteCleanedBillByID 按 ID 软删除清洗后的账单(设置 is_deleted = true)
|
||
func (r *Repository) DeleteCleanedBillByID(id string) error {
|
||
if r.cleanedCollection == nil {
|
||
return fmt.Errorf("cleaned collection not initialized")
|
||
}
|
||
|
||
oid, err := primitive.ObjectIDFromHex(id)
|
||
if err != nil {
|
||
return fmt.Errorf("invalid id: %w", err)
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||
defer cancel()
|
||
|
||
filter := bson.M{"_id": oid}
|
||
update := bson.M{"$set": bson.M{"is_deleted": true}}
|
||
result, err := r.cleanedCollection.UpdateOne(ctx, filter, update)
|
||
if err != nil {
|
||
return fmt.Errorf("soft delete bill failed: %w", err)
|
||
}
|
||
|
||
if result.MatchedCount == 0 {
|
||
return repository.ErrNotFound
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// SoftDeleteJDRelatedBills 软删除描述中包含"京东-订单编号"的非京东账单
|
||
// 用于避免京东账单与其他来源(微信、支付宝)账单重复计算
|
||
func (r *Repository) SoftDeleteJDRelatedBills() (int64, error) {
|
||
if r.cleanedCollection == nil {
|
||
return 0, fmt.Errorf("cleaned collection not initialized")
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||
defer cancel()
|
||
|
||
// 筛选条件:
|
||
// 1. 账单类型不是 jd(只处理微信、支付宝等其他来源)
|
||
// 2. 描述中包含"京东-订单编号"
|
||
// 3. 尚未被删除
|
||
filter := bson.M{
|
||
"bill_type": bson.M{"$ne": "jd"},
|
||
"description": bson.M{"$regex": "京东-订单编号", "$options": ""},
|
||
"is_deleted": bson.M{"$ne": true},
|
||
}
|
||
|
||
update := bson.M{
|
||
"$set": bson.M{
|
||
"is_deleted": true,
|
||
"updated_at": time.Now(),
|
||
},
|
||
}
|
||
|
||
result, err := r.cleanedCollection.UpdateMany(ctx, filter, update)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("soft delete JD related bills failed: %w", err)
|
||
}
|
||
|
||
return result.ModifiedCount, nil
|
||
}
|
||
|
||
// GetClient 获取 MongoDB 客户端(用于兼容旧代码)
|
||
func (r *Repository) GetClient() *mongo.Client {
|
||
return r.client
|
||
}
|
||
|
||
// GetDB 获取数据库实例(用于兼容旧代码)
|
||
func (r *Repository) GetDB() *mongo.Database {
|
||
return r.db
|
||
}
|
||
|
||
// GetRawCollection 获取原始数据集合(用于兼容旧代码)
|
||
func (r *Repository) GetRawCollection() *mongo.Collection {
|
||
return r.rawCollection
|
||
}
|
||
|
||
// GetCleanedCollection 获取清洗后数据集合(用于兼容旧代码)
|
||
func (r *Repository) GetCleanedCollection() *mongo.Collection {
|
||
return r.cleanedCollection
|
||
}
|
||
|
||
// 确保 Repository 实现了 repository.BillRepository 接口
|
||
var _ repository.BillRepository = (*Repository)(nil)
|