- 新增 LocalTime 自定义类型,JSON序列化输出本地时间格式 - 修改 CleanedBill.Time 字段类型为 LocalTime - 更新 parseTime 函数返回 LocalTime 类型 - 前端添加 formatDateTime 工具函数(兼容处理) - 版本号更新至 1.0.2
408 lines
11 KiB
Go
408 lines
11 KiB
Go
// Package mongo 实现基于 MongoDB 的账单存储
|
|
package mongo
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"go.mongodb.org/mongo-driver/bson"
|
|
"go.mongodb.org/mongo-driver/mongo"
|
|
"go.mongodb.org/mongo-driver/mongo/options"
|
|
|
|
"billai-server/config"
|
|
"billai-server/model"
|
|
"billai-server/repository"
|
|
)
|
|
|
|
// Repository MongoDB 账单存储实现
|
|
type Repository struct {
|
|
client *mongo.Client
|
|
db *mongo.Database
|
|
rawCollection *mongo.Collection
|
|
cleanedCollection *mongo.Collection
|
|
}
|
|
|
|
// NewRepository 创建 MongoDB 存储实例
|
|
func NewRepository() *Repository {
|
|
return &Repository{}
|
|
}
|
|
|
|
// Connect 连接 MongoDB
|
|
func (r *Repository) Connect() error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
// 创建客户端选项
|
|
clientOptions := options.Client().ApplyURI(config.Global.MongoURI)
|
|
|
|
// 连接 MongoDB
|
|
client, err := mongo.Connect(ctx, clientOptions)
|
|
if err != nil {
|
|
return fmt.Errorf("连接 MongoDB 失败: %w", err)
|
|
}
|
|
|
|
// 测试连接
|
|
if err := client.Ping(ctx, nil); err != nil {
|
|
return fmt.Errorf("MongoDB Ping 失败: %w", err)
|
|
}
|
|
|
|
// 设置实例变量
|
|
r.client = client
|
|
r.db = client.Database(config.Global.MongoDatabase)
|
|
r.rawCollection = r.db.Collection(config.Global.MongoRawCollection)
|
|
r.cleanedCollection = r.db.Collection(config.Global.MongoCleanedCollection)
|
|
|
|
fmt.Printf("🍃 MongoDB 连接成功: %s\n", config.Global.MongoDatabase)
|
|
fmt.Printf(" 📄 原始数据集合: %s\n", config.Global.MongoRawCollection)
|
|
fmt.Printf(" 📄 清洗数据集合: %s\n", config.Global.MongoCleanedCollection)
|
|
return nil
|
|
}
|
|
|
|
// Disconnect 断开 MongoDB 连接
|
|
func (r *Repository) Disconnect() error {
|
|
if r.client == nil {
|
|
return nil
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
if err := r.client.Disconnect(ctx); err != nil {
|
|
return fmt.Errorf("断开 MongoDB 连接失败: %w", err)
|
|
}
|
|
|
|
fmt.Println("🍃 MongoDB 连接已断开")
|
|
return nil
|
|
}
|
|
|
|
// SaveRawBills 保存原始账单数据
|
|
func (r *Repository) SaveRawBills(bills []model.RawBill) (int, error) {
|
|
if len(bills) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
// 转换为 interface{} 切片
|
|
docs := make([]interface{}, len(bills))
|
|
for i, bill := range bills {
|
|
docs[i] = bill
|
|
}
|
|
|
|
result, err := r.rawCollection.InsertMany(ctx, docs)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("插入原始账单失败: %w", err)
|
|
}
|
|
|
|
return len(result.InsertedIDs), nil
|
|
}
|
|
|
|
// SaveCleanedBills 保存清洗后的账单数据(带去重)
|
|
func (r *Repository) SaveCleanedBills(bills []model.CleanedBill) (saved int, duplicates int, err error) {
|
|
if len(bills) == 0 {
|
|
return 0, 0, nil
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
for _, bill := range bills {
|
|
// 检查是否重复
|
|
isDup, err := r.CheckCleanedDuplicate(&bill)
|
|
if err != nil {
|
|
return saved, duplicates, err
|
|
}
|
|
if isDup {
|
|
duplicates++
|
|
continue
|
|
}
|
|
|
|
// 插入新记录
|
|
_, err = r.cleanedCollection.InsertOne(ctx, bill)
|
|
if err != nil {
|
|
return saved, duplicates, fmt.Errorf("插入清洗后账单失败: %w", err)
|
|
}
|
|
saved++
|
|
}
|
|
|
|
return saved, duplicates, nil
|
|
}
|
|
|
|
// CheckRawDuplicate 检查原始数据是否重复
|
|
func (r *Repository) CheckRawDuplicate(fieldName, value string) (bool, error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
filter := bson.M{"raw_data." + fieldName: value}
|
|
count, err := r.rawCollection.CountDocuments(ctx, filter)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return count > 0, nil
|
|
}
|
|
|
|
// CheckCleanedDuplicate 检查清洗后数据是否重复
|
|
func (r *Repository) CheckCleanedDuplicate(bill *model.CleanedBill) (bool, error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
var filter bson.M
|
|
|
|
if bill.TransactionID != "" {
|
|
// 优先用交易订单号判断
|
|
filter = bson.M{"transaction_id": bill.TransactionID}
|
|
} else {
|
|
// 回退到 时间+金额+商户 组合判断
|
|
filter = bson.M{
|
|
"time": bill.Time.Time(), // 转换为 time.Time 用于 MongoDB 查询
|
|
"amount": bill.Amount,
|
|
"merchant": bill.Merchant,
|
|
}
|
|
}
|
|
|
|
count, err := r.cleanedCollection.CountDocuments(ctx, filter)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return count > 0, nil
|
|
}
|
|
|
|
// CountRawByField 按字段统计原始数据数量
|
|
func (r *Repository) CountRawByField(fieldName, value string) (int64, error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
filter := bson.M{"raw_data." + fieldName: value}
|
|
return r.rawCollection.CountDocuments(ctx, filter)
|
|
}
|
|
|
|
// GetCleanedBills 获取清洗后的账单列表
|
|
func (r *Repository) GetCleanedBills(filter map[string]interface{}) ([]model.CleanedBill, error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
// 转换 filter
|
|
bsonFilter := bson.M{}
|
|
for k, v := range filter {
|
|
bsonFilter[k] = v
|
|
}
|
|
|
|
// 按时间倒序排列
|
|
opts := options.Find().SetSort(bson.D{{Key: "time", Value: -1}})
|
|
|
|
cursor, err := r.cleanedCollection.Find(ctx, bsonFilter, opts)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查询账单失败: %w", err)
|
|
}
|
|
defer cursor.Close(ctx)
|
|
|
|
var bills []model.CleanedBill
|
|
if err := cursor.All(ctx, &bills); err != nil {
|
|
return nil, fmt.Errorf("解析账单数据失败: %w", err)
|
|
}
|
|
|
|
return bills, nil
|
|
}
|
|
|
|
// GetCleanedBillsPaged 获取清洗后的账单列表(带分页)
|
|
func (r *Repository) GetCleanedBillsPaged(filter map[string]interface{}, page, pageSize int) ([]model.CleanedBill, int64, error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
// 转换 filter
|
|
bsonFilter := bson.M{}
|
|
for k, v := range filter {
|
|
bsonFilter[k] = v
|
|
}
|
|
|
|
// 计算总数
|
|
total, err := r.cleanedCollection.CountDocuments(ctx, bsonFilter)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("统计账单数量失败: %w", err)
|
|
}
|
|
|
|
// 计算跳过数量
|
|
skip := int64((page - 1) * pageSize)
|
|
|
|
// 查询选项:分页 + 按时间倒序
|
|
opts := options.Find().
|
|
SetSort(bson.D{{Key: "time", Value: -1}}).
|
|
SetSkip(skip).
|
|
SetLimit(int64(pageSize))
|
|
|
|
cursor, err := r.cleanedCollection.Find(ctx, bsonFilter, opts)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("查询账单失败: %w", err)
|
|
}
|
|
defer cursor.Close(ctx)
|
|
|
|
var bills []model.CleanedBill
|
|
if err := cursor.All(ctx, &bills); err != nil {
|
|
return nil, 0, fmt.Errorf("解析账单数据失败: %w", err)
|
|
}
|
|
|
|
return bills, total, nil
|
|
}
|
|
|
|
// GetBillsAggregate 获取账单聚合统计(总收入、总支出)
|
|
func (r *Repository) GetBillsAggregate(filter map[string]interface{}) (totalExpense float64, totalIncome float64, err error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
// 转换 filter
|
|
bsonFilter := bson.M{}
|
|
for k, v := range filter {
|
|
bsonFilter[k] = v
|
|
}
|
|
|
|
// 使用聚合管道按 income_expense 分组统计金额
|
|
pipeline := mongo.Pipeline{
|
|
{{Key: "$match", Value: bsonFilter}},
|
|
{{Key: "$group", Value: bson.D{
|
|
{Key: "_id", Value: "$income_expense"},
|
|
{Key: "total", Value: bson.D{{Key: "$sum", Value: "$amount"}}},
|
|
}}},
|
|
}
|
|
|
|
cursor, err := r.cleanedCollection.Aggregate(ctx, pipeline)
|
|
if err != nil {
|
|
return 0, 0, fmt.Errorf("聚合统计失败: %w", err)
|
|
}
|
|
defer cursor.Close(ctx)
|
|
|
|
// 解析结果
|
|
var results []struct {
|
|
ID string `bson:"_id"`
|
|
Total float64 `bson:"total"`
|
|
}
|
|
if err := cursor.All(ctx, &results); err != nil {
|
|
return 0, 0, fmt.Errorf("解析聚合结果失败: %w", err)
|
|
}
|
|
|
|
for _, result := range results {
|
|
switch result.ID {
|
|
case "支出":
|
|
totalExpense = result.Total
|
|
case "收入":
|
|
totalIncome = result.Total
|
|
}
|
|
}
|
|
|
|
return totalExpense, totalIncome, nil
|
|
}
|
|
|
|
// GetBillsNeedReview 获取需要复核的账单
|
|
func (r *Repository) GetBillsNeedReview() ([]model.CleanedBill, error) {
|
|
filter := map[string]interface{}{
|
|
"review_level": bson.M{"$in": []string{"HIGH", "LOW"}},
|
|
}
|
|
return r.GetCleanedBills(filter)
|
|
}
|
|
|
|
// GetMonthlyStats 获取月度统计(全部数据,不受筛选条件影响)
|
|
func (r *Repository) GetMonthlyStats() ([]model.MonthlyStat, error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
// 使用聚合管道按月份分组统计
|
|
// 先按月份和收支类型分组,再汇总
|
|
pipeline := mongo.Pipeline{
|
|
// 添加月份字段
|
|
{{Key: "$addFields", Value: bson.D{
|
|
{Key: "month", Value: bson.D{
|
|
{Key: "$dateToString", Value: bson.D{
|
|
{Key: "format", Value: "%Y-%m"},
|
|
{Key: "date", Value: "$time"},
|
|
}},
|
|
}},
|
|
}}},
|
|
// 按月份和收支类型分组
|
|
{{Key: "$group", Value: bson.D{
|
|
{Key: "_id", Value: bson.D{
|
|
{Key: "month", Value: "$month"},
|
|
{Key: "income_expense", Value: "$income_expense"},
|
|
}},
|
|
{Key: "total", Value: bson.D{{Key: "$sum", Value: "$amount"}}},
|
|
}}},
|
|
// 按月份重新分组,汇总收入和支出
|
|
{{Key: "$group", Value: bson.D{
|
|
{Key: "_id", Value: "$_id.month"},
|
|
{Key: "expense", Value: bson.D{
|
|
{Key: "$sum", Value: bson.D{
|
|
{Key: "$cond", Value: bson.A{
|
|
bson.D{{Key: "$eq", Value: bson.A{"$_id.income_expense", "支出"}}},
|
|
"$total",
|
|
0,
|
|
}},
|
|
}},
|
|
}},
|
|
{Key: "income", Value: bson.D{
|
|
{Key: "$sum", Value: bson.D{
|
|
{Key: "$cond", Value: bson.A{
|
|
bson.D{{Key: "$eq", Value: bson.A{"$_id.income_expense", "收入"}}},
|
|
"$total",
|
|
0,
|
|
}},
|
|
}},
|
|
}},
|
|
}}},
|
|
// 按月份排序
|
|
{{Key: "$sort", Value: bson.D{{Key: "_id", Value: 1}}}},
|
|
}
|
|
|
|
cursor, err := r.cleanedCollection.Aggregate(ctx, pipeline)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("月度统计聚合失败: %w", err)
|
|
}
|
|
defer cursor.Close(ctx)
|
|
|
|
// 解析结果
|
|
var results []struct {
|
|
Month string `bson:"_id"`
|
|
Expense float64 `bson:"expense"`
|
|
Income float64 `bson:"income"`
|
|
}
|
|
if err := cursor.All(ctx, &results); err != nil {
|
|
return nil, fmt.Errorf("解析月度统计结果失败: %w", err)
|
|
}
|
|
|
|
// 转换为 MonthlyStat
|
|
stats := make([]model.MonthlyStat, len(results))
|
|
for i, r := range results {
|
|
stats[i] = model.MonthlyStat{
|
|
Month: r.Month,
|
|
Expense: r.Expense,
|
|
Income: r.Income,
|
|
}
|
|
}
|
|
|
|
return stats, nil
|
|
}
|
|
|
|
// GetClient 获取 MongoDB 客户端(用于兼容旧代码)
|
|
func (r *Repository) GetClient() *mongo.Client {
|
|
return r.client
|
|
}
|
|
|
|
// GetDB 获取数据库实例(用于兼容旧代码)
|
|
func (r *Repository) GetDB() *mongo.Database {
|
|
return r.db
|
|
}
|
|
|
|
// GetRawCollection 获取原始数据集合(用于兼容旧代码)
|
|
func (r *Repository) GetRawCollection() *mongo.Collection {
|
|
return r.rawCollection
|
|
}
|
|
|
|
// GetCleanedCollection 获取清洗后数据集合(用于兼容旧代码)
|
|
func (r *Repository) GetCleanedCollection() *mongo.Collection {
|
|
return r.cleanedCollection
|
|
}
|
|
|
|
// 确保 Repository 实现了 repository.BillRepository 接口
|
|
var _ repository.BillRepository = (*Repository)(nil)
|