Files
billai/server/repository/mongo/repository.go
cheliangzhao eb76c3a8dc fix: 修复微信账单金额解析问题(半角¥符号支持)
- 修复 parse_amount 函数同时支持全角￥和半角¥
- 新增 MonthRangePicker 日期选择组件
- 新增 /api/monthly-stats 接口获取月度统计
- 分析页面月度趋势使用全量数据
- 新增健康检查路由
2026-01-10 19:21:24 +08:00

408 lines
11 KiB
Go

// Package mongo 实现基于 MongoDB 的账单存储
package mongo
import (
"context"
"fmt"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"billai-server/config"
"billai-server/model"
"billai-server/repository"
)
// Repository is the MongoDB-backed bill store. All handles below are filled in
// by Connect; methods that touch the collections require a successful Connect
// first.
type Repository struct {
	// client is the driver client; it stays nil until Connect succeeds.
	client *mongo.Client

	// db is the database named by config.Global.MongoDatabase.
	db *mongo.Database

	// rawCollection and cleanedCollection hold imported bills before and
	// after cleaning, respectively.
	rawCollection, cleanedCollection *mongo.Collection
}
// NewRepository returns an unconnected Repository; call Connect before using it.
func NewRepository() *Repository {
	return new(Repository)
}
// Connect dials MongoDB at config.Global.MongoURI and verifies the connection
// with a ping. On success it binds the database and the raw/cleaned
// collections named in config.Global. Dialing and pinging share a single
// 10-second timeout.
//
// If the ping fails, the freshly created client is disconnected before the
// error is returned. Otherwise every failed Connect would leak the driver's
// background resources (monitoring goroutines, pooled connections). The
// Repository is left untouched on any failure.
func (r *Repository) Connect() error {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	clientOptions := options.Client().ApplyURI(config.Global.MongoURI)
	client, err := mongo.Connect(ctx, clientOptions)
	if err != nil {
		return fmt.Errorf("连接 MongoDB 失败: %w", err)
	}

	if err := client.Ping(ctx, nil); err != nil {
		// ctx may already be expired (a ping timeout is the usual failure), so
		// release the client under its own short deadline.
		discCtx, discCancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer discCancel()
		// Best effort: the ping error is the one worth reporting.
		_ = client.Disconnect(discCtx)
		return fmt.Errorf("MongoDB Ping 失败: %w", err)
	}

	r.client = client
	r.db = client.Database(config.Global.MongoDatabase)
	r.rawCollection = r.db.Collection(config.Global.MongoRawCollection)
	r.cleanedCollection = r.db.Collection(config.Global.MongoCleanedCollection)

	fmt.Printf("🍃 MongoDB 连接成功: %s\n", config.Global.MongoDatabase)
	fmt.Printf(" 📄 原始数据集合: %s\n", config.Global.MongoRawCollection)
	fmt.Printf(" 📄 清洗数据集合: %s\n", config.Global.MongoCleanedCollection)
	return nil
}
// Disconnect closes the client opened by Connect, giving the driver up to 10
// seconds to shut down. It does nothing when Connect never succeeded.
func (r *Repository) Disconnect() error {
	if r.client != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()

		if closeErr := r.client.Disconnect(ctx); closeErr != nil {
			return fmt.Errorf("断开 MongoDB 连接失败: %w", closeErr)
		}
		fmt.Println("🍃 MongoDB 连接已断开")
	}
	return nil
}
// SaveRawBills inserts all bills into the raw collection with one InsertMany
// call. It returns the number of IDs the driver reports as inserted. An empty
// slice is a no-op, and the whole batch shares a 30-second timeout.
//
// NOTE(review): on error the returned count is always 0, even though the
// driver may have inserted part of the batch before failing.
func (r *Repository) SaveRawBills(bills []model.RawBill) (int, error) {
	if len(bills) == 0 {
		return 0, nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// InsertMany takes a slice of generic documents, not []model.RawBill.
	documents := make([]interface{}, 0, len(bills))
	for _, b := range bills {
		documents = append(documents, b)
	}

	res, insertErr := r.rawCollection.InsertMany(ctx, documents)
	if insertErr != nil {
		return 0, fmt.Errorf("插入原始账单失败: %w", insertErr)
	}
	return len(res.InsertedIDs), nil
}
// SaveCleanedBills inserts each bill into the cleaned collection unless
// CheckCleanedDuplicate reports it already exists.
//
// It returns how many bills were saved and how many were skipped as
// duplicates. Bills are processed in order. Because each bill is checked
// against the collection right before it is inserted, repeats within the same
// batch are also caught. On the first error it stops and returns the counts
// accumulated so far together with the error.
//
// Each insert gets its own timeout rather than sharing one deadline across
// the whole batch. A shared deadline made large imports fail partway with
// "context deadline exceeded".
//
// NOTE(review): the check and the insert are separate round trips, so two
// concurrent imports of the same bill could both pass the check. A unique
// index would be needed to rule that out.
func (r *Repository) SaveCleanedBills(bills []model.CleanedBill) (saved int, duplicates int, err error) {
	for i := range bills {
		// Index instead of ranging by value to avoid copying each bill.
		bill := &bills[i]

		isDup, checkErr := r.CheckCleanedDuplicate(bill)
		if checkErr != nil {
			return saved, duplicates, checkErr
		}
		if isDup {
			duplicates++
			continue
		}

		if insertErr := r.insertCleanedBill(bill); insertErr != nil {
			return saved, duplicates, insertErr
		}
		saved++
	}
	return saved, duplicates, nil
}

// insertCleanedBill inserts a single cleaned bill under its own 5-second
// timeout, the same budget CheckCleanedDuplicate uses for one document.
func (r *Repository) insertCleanedBill(bill *model.CleanedBill) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Insert a copy of the value, exactly as the original loop did.
	if _, err := r.cleanedCollection.InsertOne(ctx, *bill); err != nil {
		return fmt.Errorf("插入清洗后账单失败: %w", err)
	}
	return nil
}
// CheckRawDuplicate reports whether any raw document has raw_data.<fieldName>
// equal to value, within a 5-second timeout.
//
// Only existence matters, so the count is capped at one document. This lets
// the server stop at the first match instead of counting every duplicate.
// Driver errors are returned unwrapped.
func (r *Repository) CheckRawDuplicate(fieldName, value string) (bool, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	filter := bson.M{"raw_data." + fieldName: value}
	count, err := r.rawCollection.CountDocuments(ctx, filter, options.Count().SetLimit(1))
	if err != nil {
		return false, err
	}
	return count > 0, nil
}
// CheckCleanedDuplicate reports whether bill already exists in the cleaned
// collection, within a 5-second timeout.
//
// A non-empty TransactionID is treated as authoritative and matched on its
// own. Otherwise the bill is matched on the exact (time, amount, merchant)
// combination.
//
// This runs once per bill during imports and only existence matters, so the
// count is capped at one match to let the server stop scanning early. Driver
// errors are returned unwrapped.
func (r *Repository) CheckCleanedDuplicate(bill *model.CleanedBill) (bool, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	filter := bson.M{"transaction_id": bill.TransactionID}
	if bill.TransactionID == "" {
		// No order number: fall back to time + amount + merchant.
		filter = bson.M{
			"time":     bill.Time,
			"amount":   bill.Amount,
			"merchant": bill.Merchant,
		}
	}

	count, err := r.cleanedCollection.CountDocuments(ctx, filter, options.Count().SetLimit(1))
	if err != nil {
		return false, err
	}
	return count > 0, nil
}
// CountRawByField returns how many raw documents have raw_data.<fieldName>
// equal to value, within a 5-second timeout.
func (r *Repository) CountRawByField(fieldName, value string) (int64, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	return r.rawCollection.CountDocuments(ctx, bson.M{
		"raw_data." + fieldName: value,
	})
}
// GetCleanedBills returns every cleaned bill matching filter, newest first
// (sorted by "time" descending).
//
// The caller's map is copied into a fresh bson.M before querying. The query
// and the cursor decoding share a 30-second timeout.
func (r *Repository) GetCleanedBills(filter map[string]interface{}) ([]model.CleanedBill, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	query := make(bson.M, len(filter))
	for key, val := range filter {
		query[key] = val
	}

	newestFirst := options.Find().SetSort(bson.D{{Key: "time", Value: -1}})
	cur, err := r.cleanedCollection.Find(ctx, query, newestFirst)
	if err != nil {
		return nil, fmt.Errorf("查询账单失败: %w", err)
	}
	defer cur.Close(ctx)

	var out []model.CleanedBill
	if decodeErr := cur.All(ctx, &out); decodeErr != nil {
		return nil, fmt.Errorf("解析账单数据失败: %w", decodeErr)
	}
	return out, nil
}
// GetCleanedBillsPaged returns one page of cleaned bills matching filter,
// newest first, together with the total number of matching bills.
//
// page is 1-based. Values below 1 are treated as 1, so the driver is never
// sent a negative skip (the server rejects one). pageSize is passed through
// as the find limit; a limit of 0 means "no limit" to MongoDB. Counting and
// fetching share a single 30-second timeout.
func (r *Repository) GetCleanedBillsPaged(filter map[string]interface{}, page, pageSize int) ([]model.CleanedBill, int64, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	bsonFilter := make(bson.M, len(filter))
	for k, v := range filter {
		bsonFilter[k] = v
	}

	total, err := r.cleanedCollection.CountDocuments(ctx, bsonFilter)
	if err != nil {
		return nil, 0, fmt.Errorf("统计账单数量失败: %w", err)
	}

	if page < 1 {
		page = 1
	}
	// Multiply in int64 so large page numbers cannot overflow a 32-bit int.
	skip := int64(page-1) * int64(pageSize)
	if skip < 0 {
		// A negative pageSize would otherwise produce a negative skip.
		skip = 0
	}

	opts := options.Find().
		SetSort(bson.D{{Key: "time", Value: -1}}).
		SetSkip(skip).
		SetLimit(int64(pageSize))

	cursor, err := r.cleanedCollection.Find(ctx, bsonFilter, opts)
	if err != nil {
		return nil, 0, fmt.Errorf("查询账单失败: %w", err)
	}
	defer cursor.Close(ctx)

	var bills []model.CleanedBill
	if err := cursor.All(ctx, &bills); err != nil {
		return nil, 0, fmt.Errorf("解析账单数据失败: %w", err)
	}
	return bills, total, nil
}
// GetBillsAggregate sums the amount of every cleaned bill matching filter,
// grouped by income_expense.
//
// It returns the total for "支出" (expense) and the total for "收入" (income).
// Bills with any other income_expense value are ignored, and a category with
// no matching bills yields 0. The aggregation shares a 30-second timeout.
func (r *Repository) GetBillsAggregate(filter map[string]interface{}) (totalExpense float64, totalIncome float64, err error) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	match := make(bson.M, len(filter))
	for field, cond := range filter {
		match[field] = cond
	}

	sumByKind := bson.D{
		{Key: "_id", Value: "$income_expense"},
		{Key: "total", Value: bson.D{{Key: "$sum", Value: "$amount"}}},
	}
	pipeline := mongo.Pipeline{
		{{Key: "$match", Value: match}},
		{{Key: "$group", Value: sumByKind}},
	}

	cur, aggErr := r.cleanedCollection.Aggregate(ctx, pipeline)
	if aggErr != nil {
		return 0, 0, fmt.Errorf("聚合统计失败: %w", aggErr)
	}
	defer cur.Close(ctx)

	type kindTotal struct {
		ID    string  `bson:"_id"`
		Total float64 `bson:"total"`
	}
	var groups []kindTotal
	if decodeErr := cur.All(ctx, &groups); decodeErr != nil {
		return 0, 0, fmt.Errorf("解析聚合结果失败: %w", decodeErr)
	}

	for _, g := range groups {
		if g.ID == "支出" {
			totalExpense = g.Total
		} else if g.ID == "收入" {
			totalIncome = g.Total
		}
	}
	return totalExpense, totalIncome, nil
}
// GetBillsNeedReview returns cleaned bills whose review_level is "HIGH" or
// "LOW". Ordering (newest first) and the timeout come from GetCleanedBills.
func (r *Repository) GetBillsNeedReview() ([]model.CleanedBill, error) {
	reviewLevels := []string{"HIGH", "LOW"}
	return r.GetCleanedBills(map[string]interface{}{
		"review_level": bson.M{"$in": reviewLevels},
	})
}
// GetMonthlyStats returns per-month expense and income totals over the entire
// cleaned collection, ordered by month ascending. No caller filter is applied.
//
// How the pipeline works:
//  1. Each bill's "time" is formatted as "YYYY-MM" with $dateToString.
//  2. Bills are summed per (month, income_expense).
//  3. Those partial sums are folded into one row per month, with expense taken
//     from the "支出" group and income from the "收入" group.
//
// NOTE(review): $dateToString is called without a timezone, so months are
// computed in UTC. If bill times are local instants (e.g. UTC+8), bills close
// to midnight at a month boundary land in the neighbouring month. Confirm how
// "time" is stored before relying on boundary accuracy.
func (r *Repository) GetMonthlyStats() ([]model.MonthlyStat, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// sumTotalWhen sums the first-stage "total" only for groups whose
	// income_expense equals kind, contributing 0 otherwise.
	sumTotalWhen := func(kind string) bson.D {
		return bson.D{{Key: "$sum", Value: bson.D{
			{Key: "$cond", Value: bson.A{
				bson.D{{Key: "$eq", Value: bson.A{"$_id.income_expense", kind}}},
				"$total",
				0,
			}},
		}}}
	}

	pipeline := mongo.Pipeline{
		// Derive the month key from the bill time.
		{{Key: "$addFields", Value: bson.D{
			{Key: "month", Value: bson.D{
				{Key: "$dateToString", Value: bson.D{
					{Key: "format", Value: "%Y-%m"},
					{Key: "date", Value: "$time"},
				}},
			}},
		}}},
		// Sum amounts per (month, income_expense).
		{{Key: "$group", Value: bson.D{
			{Key: "_id", Value: bson.D{
				{Key: "month", Value: "$month"},
				{Key: "income_expense", Value: "$income_expense"},
			}},
			{Key: "total", Value: bson.D{{Key: "$sum", Value: "$amount"}}},
		}}},
		// Fold into one row per month with separate expense/income columns.
		{{Key: "$group", Value: bson.D{
			{Key: "_id", Value: "$_id.month"},
			{Key: "expense", Value: sumTotalWhen("支出")},
			{Key: "income", Value: sumTotalWhen("收入")},
		}}},
		{{Key: "$sort", Value: bson.D{{Key: "_id", Value: 1}}}},
	}

	cursor, err := r.cleanedCollection.Aggregate(ctx, pipeline)
	if err != nil {
		return nil, fmt.Errorf("月度统计聚合失败: %w", err)
	}
	defer cursor.Close(ctx)

	var results []struct {
		Month   string  `bson:"_id"`
		Expense float64 `bson:"expense"`
		Income  float64 `bson:"income"`
	}
	if err := cursor.All(ctx, &results); err != nil {
		return nil, fmt.Errorf("解析月度统计结果失败: %w", err)
	}

	// Use "row" rather than "r" so the method receiver is not shadowed.
	stats := make([]model.MonthlyStat, len(results))
	for i, row := range results {
		stats[i] = model.MonthlyStat{
			Month:   row.Month,
			Expense: row.Expense,
			Income:  row.Income,
		}
	}
	return stats, nil
}
// GetClient exposes the underlying driver client for legacy callers. It is nil
// until Connect succeeds.
func (r *Repository) GetClient() *mongo.Client {
	return r.client
}

// GetDB exposes the configured database handle for legacy callers.
func (r *Repository) GetDB() *mongo.Database {
	return r.db
}

// GetRawCollection exposes the raw-bill collection for legacy callers.
func (r *Repository) GetRawCollection() *mongo.Collection {
	return r.rawCollection
}

// GetCleanedCollection exposes the cleaned-bill collection for legacy callers.
func (r *Repository) GetCleanedCollection() *mongo.Collection {
	return r.cleanedCollection
}
// Compile-time check that *Repository satisfies repository.BillRepository.
var (
	_ repository.BillRepository = (*Repository)(nil)
)