chore(release): v1.0.7

- README/CHANGELOG: add v1.0.7 entry\n- Server: JWT expiry validated server-side (401 codes)\n- Web: logout/redirect on 401; proxy forwards Authorization\n- Server: bill service uses repository consistently
This commit is contained in:
CHE LIANG ZHAO
2026-01-16 11:15:05 +08:00
parent ad6a6d44ea
commit 3b7c1cd82b
17 changed files with 226 additions and 250 deletions

View File

@@ -1,7 +1,7 @@
# BillAI 服务器配置文件
# 应用版本
version: "1.0.6"
version: "1.0.7"
# 服务配置
server:

View File

@@ -145,7 +145,7 @@ func Load() {
flag.Parse()
// 设置默认值
Global.Version = "0.0.1"
Global.Version = "1.0.7"
Global.Port = getEnvOrDefault("PORT", "8080")
Global.ProjectRoot = getDefaultProjectRoot()
Global.PythonPath = getDefaultPythonPath()

View File

@@ -69,4 +69,3 @@ func Disconnect() error {
fmt.Println("🍃 MongoDB 连接已断开")
return nil
}

View File

@@ -3,6 +3,7 @@ package handler
import (
"crypto/sha256"
"encoding/hex"
"errors"
"net/http"
"time"
@@ -131,6 +132,7 @@ func ValidateToken(c *gin.Context) {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"error": "未提供 Token",
"code": "TOKEN_MISSING",
})
return
}
@@ -147,12 +149,20 @@ func ValidateToken(c *gin.Context) {
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
return []byte(secret), nil
})
}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
if err != nil || !token.Valid {
code := "TOKEN_INVALID"
message := "Token 无效"
if err != nil && errors.Is(err, jwt.ErrTokenExpired) {
code = "TOKEN_EXPIRED"
message = "Token 已过期"
}
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"error": "Token 无效或已过期",
"error": message,
"code": code,
})
return
}
@@ -162,6 +172,7 @@ func ValidateToken(c *gin.Context) {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"error": "Token 解析失败",
"code": "TOKEN_INVALID",
})
return
}

View File

@@ -69,4 +69,3 @@ func Review(c *gin.Context) {
},
})
}

View File

@@ -12,7 +12,6 @@ import (
adapterHttp "billai-server/adapter/http"
"billai-server/adapter/python"
"billai-server/config"
"billai-server/database"
"billai-server/repository"
repoMongo "billai-server/repository/mongo"
"billai-server/router"
@@ -44,21 +43,13 @@ func main() {
initAdapters()
// 初始化数据层
if err := initRepository(); err != nil {
repo, err := initRepository()
if err != nil {
fmt.Printf("⚠️ 警告: 数据层初始化失败: %v\n", err)
fmt.Println(" 账单数据将不会存储到数据库")
os.Exit(1)
}
// 连接 MongoDB保持兼容旧代码后续可移除
if err := database.Connect(); err != nil {
fmt.Printf("⚠️ 警告: MongoDB 连接失败: %v\n", err)
fmt.Println(" 账单数据将不会存储到数据库")
os.Exit(1)
} else {
// 优雅关闭时断开连接
defer database.Disconnect()
}
defer repo.Disconnect()
// 创建路由
r := gin.Default()
@@ -75,7 +66,7 @@ func main() {
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
fmt.Println("\n🛑 正在关闭服务...")
database.Disconnect()
repo.Disconnect()
os.Exit(0)
}()
@@ -153,14 +144,14 @@ func initAdapters() {
// initRepository 初始化数据存储层
// 在这里配置数据持久化方式
// 后续可以通过修改这里来切换不同的存储实现(如 PostgreSQL、MySQL 等)
func initRepository() error {
func initRepository() (repository.BillRepository, error) {
// 初始化 MongoDB 存储
mongoRepo := repoMongo.NewRepository()
if err := mongoRepo.Connect(); err != nil {
return err
return nil, err
}
repository.SetRepository(mongoRepo)
fmt.Println("💾 数据层初始化完成")
return nil
return mongoRepo, nil
}

75
server/middleware/auth.go Normal file
View File

@@ -0,0 +1,75 @@
package middleware
import (
"errors"
"net/http"
"strings"
"billai-server/config"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
)
// Claims JWT claims (duplicated here to avoid cross-package import from handler).
type Claims struct {
Username string `json:"username"`
Name string `json:"name"`
Role string `json:"role"`
jwt.RegisteredClaims
}
func AuthRequired() gin.HandlerFunc {
return func(c *gin.Context) {
tokenString := c.GetHeader("Authorization")
if tokenString == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"error": "未提供 Token",
"code": "TOKEN_MISSING",
})
c.Abort()
return
}
if strings.HasPrefix(tokenString, "Bearer ") {
tokenString = strings.TrimPrefix(tokenString, "Bearer ")
}
secret := config.Global.JWTSecret
if secret == "" {
secret = "billai-default-secret"
}
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
return []byte(secret), nil
}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
if err != nil || !token.Valid {
code := "TOKEN_INVALID"
message := "Token 无效"
if err != nil && errors.Is(err, jwt.ErrTokenExpired) {
code = "TOKEN_EXPIRED"
message = "Token 已过期"
}
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"error": message,
"code": code,
})
c.Abort()
return
}
if claims, ok := token.Claims.(*Claims); ok {
c.Set("user", gin.H{
"username": claims.Username,
"name": claims.Name,
"role": claims.Role,
})
}
c.Next()
}
}

View File

@@ -43,4 +43,3 @@ type ReviewResponse struct {
Message string `json:"message"`
Data *ReviewData `json:"data,omitempty"`
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/gin-gonic/gin"
"billai-server/handler"
"billai-server/middleware"
)
// Config 路由配置参数
@@ -45,22 +46,27 @@ func setupAPIRoutes(r *gin.Engine) {
api.POST("/auth/login", handler.Login)
api.GET("/auth/validate", handler.ValidateToken)
// 账单上传
api.POST("/upload", handler.Upload)
// 需要登录的 API
authed := api.Group("/")
authed.Use(middleware.AuthRequired())
{
// 账单上传
authed.POST("/upload", handler.Upload)
// 复核相关
api.GET("/review", handler.Review)
// 复核相关
authed.GET("/review", handler.Review)
// 账单查询
api.GET("/bills", handler.ListBills)
// 账单查询
authed.GET("/bills", handler.ListBills)
// 手动创建账单
api.POST("/bills/manual", handler.CreateManualBills)
// 手动创建账单
authed.POST("/bills/manual", handler.CreateManualBills)
// 月度统计(全部数据)
api.GET("/monthly-stats", handler.MonthlyStats)
// 月度统计(全部数据)
authed.GET("/monthly-stats", handler.MonthlyStats)
// 待复核数据统计
api.GET("/review-stats", handler.ReviewStats)
// 待复核数据统计
authed.GET("/review-stats", handler.ReviewStats)
}
}
}

View File

@@ -1,7 +1,8 @@
package service
import (
"context"
"billai-server/model"
"billai-server/repository"
"encoding/csv"
"encoding/json"
"fmt"
@@ -9,11 +10,6 @@ import (
"strconv"
"strings"
"time"
"go.mongodb.org/mongo-driver/bson"
"billai-server/database"
"billai-server/model"
)
// SaveResult 存储结果
@@ -23,29 +19,8 @@ type SaveResult struct {
DuplicateCount int // 重复数据跳过数量
}
// checkDuplicate 检查记录是否重复
// 优先使用 transaction_id 判断,如果为空则使用 时间+金额+商户 组合判断
func checkDuplicate(ctx context.Context, bill *model.CleanedBill) bool {
var filter bson.M
if bill.TransactionID != "" {
// 优先用交易订单号判断
filter = bson.M{"transaction_id": bill.TransactionID}
} else {
// 回退到 时间+金额+商户 组合判断
filter = bson.M{
"time": bill.Time.Time(), // 转换为 time.Time 用于 MongoDB 查询
"amount": bill.Amount,
"merchant": bill.Merchant,
}
}
count, err := database.CleanedBillCollection.CountDocuments(ctx, filter)
if err != nil {
return false // 查询出错时不认为是重复
}
return count > 0
func getRepo() repository.BillRepository {
return repository.GetRepository()
}
// DeduplicateResult 去重结果
@@ -60,6 +35,11 @@ type DeduplicateResult struct {
// DeduplicateRawFile 对原始文件进行去重检查,返回去重后的文件路径
// 如果全部重复,返回错误
func DeduplicateRawFile(filePath, uploadBatch string) (*DeduplicateResult, error) {
repo := getRepo()
if repo == nil {
return nil, fmt.Errorf("数据库未连接")
}
file, err := os.Open(filePath)
if err != nil {
return nil, fmt.Errorf("打开文件失败: %w", err)
@@ -94,10 +74,6 @@ func DeduplicateRawFile(filePath, uploadBatch string) (*DeduplicateResult, error
return result, nil
}
// 创建上下文
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// 检查每行是否重复
var newRows [][]string
for _, row := range dataRows {
@@ -112,17 +88,14 @@ func DeduplicateRawFile(filePath, uploadBatch string) (*DeduplicateResult, error
continue
}
// 检查是否已存在
count, err := database.RawBillCollection.CountDocuments(ctx, bson.M{
"raw_data." + header[idFieldIdx]: transactionID,
})
isDup, err := repo.CheckRawDuplicate(header[idFieldIdx], transactionID)
if err != nil {
// 查询出错,保留该行
newRows = append(newRows, row)
continue
}
if count == 0 {
if !isDup {
// 不重复,保留
newRows = append(newRows, row)
} else {
@@ -198,6 +171,11 @@ func detectBillTypeAndIdField(header []string) (billType string, idFieldIdx int)
// SaveRawBillsFromFile 从原始上传文件读取数据并存入原始数据集合
func SaveRawBillsFromFile(filePath, billType, sourceFile, uploadBatch string) (int, error) {
repo := getRepo()
if repo == nil {
return 0, fmt.Errorf("数据库未连接")
}
file, err := os.Open(filePath)
if err != nil {
return 0, fmt.Errorf("打开文件失败: %w", err)
@@ -219,7 +197,7 @@ func SaveRawBillsFromFile(filePath, billType, sourceFile, uploadBatch string) (i
now := time.Now()
// 构建原始数据文档
var rawBills []interface{}
var rawBills []model.RawBill
for rowIdx, row := range rows[1:] {
rawData := make(map[string]interface{})
for colIdx, col := range header {
@@ -244,16 +222,7 @@ func SaveRawBillsFromFile(filePath, billType, sourceFile, uploadBatch string) (i
return 0, nil
}
// 批量插入原始数据集合
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
result, err := database.RawBillCollection.InsertMany(ctx, rawBills)
if err != nil {
return 0, fmt.Errorf("插入原始数据失败: %w", err)
}
return len(result.InsertedIDs), nil
return repo.SaveRawBills(rawBills)
}
// SaveCleanedBillsFromFile 从清洗后的文件读取数据并存入清洗后数据集合
@@ -268,6 +237,11 @@ func SaveCleanedBillsFromFile(filePath, format, billType, sourceFile, uploadBatc
// saveCleanedBillsFromCSV 从 CSV 文件读取并存储清洗后账单
// 返回: (插入数量, 重复跳过数量, 错误)
func saveCleanedBillsFromCSV(filePath, billType, sourceFile, uploadBatch string) (int, int, error) {
repo := getRepo()
if repo == nil {
return 0, 0, fmt.Errorf("数据库未连接")
}
file, err := os.Open(filePath)
if err != nil {
return 0, 0, fmt.Errorf("打开文件失败: %w", err)
@@ -291,13 +265,8 @@ func saveCleanedBillsFromCSV(filePath, billType, sourceFile, uploadBatch string)
colIdx[col] = i
}
// 创建上下文用于去重检查
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
// 解析数据行
var bills []interface{}
duplicateCount := 0
var bills []model.CleanedBill
now := time.Now()
for _, row := range rows[1:] {
@@ -361,31 +330,24 @@ func saveCleanedBillsFromCSV(filePath, billType, sourceFile, uploadBatch string)
bill.ReviewLevel = row[idx]
}
// 检查是否重复
if checkDuplicate(ctx, &bill) {
duplicateCount++
continue // 跳过重复记录
}
bills = append(bills, bill)
}
if len(bills) == 0 {
return 0, duplicateCount, nil
}
// 批量插入清洗后数据集合
result, err := database.CleanedBillCollection.InsertMany(ctx, bills)
saved, duplicates, err := repo.SaveCleanedBills(bills)
if err != nil {
return 0, duplicateCount, fmt.Errorf("插入清洗后数据失败: %w", err)
return 0, 0, err
}
return len(result.InsertedIDs), duplicateCount, nil
return saved, duplicates, nil
}
// saveCleanedBillsFromJSON 从 JSON 文件读取并存储清洗后账单
// 返回: (插入数量, 重复跳过数量, 错误)
func saveCleanedBillsFromJSON(filePath, billType, sourceFile, uploadBatch string) (int, int, error) {
repo := getRepo()
if repo == nil {
return 0, 0, fmt.Errorf("数据库未连接")
}
file, err := os.Open(filePath)
if err != nil {
return 0, 0, fmt.Errorf("打开文件失败: %w", err)
@@ -402,13 +364,8 @@ func saveCleanedBillsFromJSON(filePath, billType, sourceFile, uploadBatch string
return 0, 0, nil
}
// 创建上下文用于去重检查
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
// 解析数据
var bills []interface{}
duplicateCount := 0
var bills []model.CleanedBill
now := time.Now()
for _, item := range data {
@@ -467,25 +424,14 @@ func saveCleanedBillsFromJSON(filePath, billType, sourceFile, uploadBatch string
bill.ReviewLevel = v
}
// 检查是否重复
if checkDuplicate(ctx, &bill) {
duplicateCount++
continue // 跳过重复记录
}
bills = append(bills, bill)
}
if len(bills) == 0 {
return 0, duplicateCount, nil
}
result, err := database.CleanedBillCollection.InsertMany(ctx, bills)
saved, duplicates, err := repo.SaveCleanedBills(bills)
if err != nil {
return 0, duplicateCount, fmt.Errorf("插入清洗后数据失败: %w", err)
return 0, 0, err
}
return len(result.InsertedIDs), duplicateCount, nil
return saved, duplicates, nil
}
// parseTime 解析时间字符串
@@ -559,106 +505,3 @@ func parseAmount(s string) float64 {
}
return 0
}
// GetCleanedBillsByBatch 根据批次获取清洗后账单
func GetCleanedBillsByBatch(uploadBatch string) ([]model.CleanedBill, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cursor, err := database.CleanedBillCollection.Find(ctx, bson.M{"upload_batch": uploadBatch})
if err != nil {
return nil, fmt.Errorf("查询失败: %w", err)
}
defer cursor.Close(ctx)
var bills []model.CleanedBill
if err := cursor.All(ctx, &bills); err != nil {
return nil, fmt.Errorf("解析结果失败: %w", err)
}
return bills, nil
}
// GetRawBillsByBatch 根据批次获取原始账单
func GetRawBillsByBatch(uploadBatch string) ([]model.RawBill, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cursor, err := database.RawBillCollection.Find(ctx, bson.M{"upload_batch": uploadBatch})
if err != nil {
return nil, fmt.Errorf("查询失败: %w", err)
}
defer cursor.Close(ctx)
var bills []model.RawBill
if err := cursor.All(ctx, &bills); err != nil {
return nil, fmt.Errorf("解析结果失败: %w", err)
}
return bills, nil
}
// GetBillStats 获取账单统计信息
func GetBillStats() (map[string]interface{}, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// 原始数据总数
rawTotal, err := database.RawBillCollection.CountDocuments(ctx, bson.M{})
if err != nil {
return nil, err
}
// 清洗后数据总数
cleanedTotal, err := database.CleanedBillCollection.CountDocuments(ctx, bson.M{})
if err != nil {
return nil, err
}
// 支出总额(从清洗后数据统计)
expensePipeline := []bson.M{
{"$match": bson.M{"income_expense": "支出"}},
{"$group": bson.M{"_id": nil, "total": bson.M{"$sum": "$amount"}}},
}
expenseCursor, err := database.CleanedBillCollection.Aggregate(ctx, expensePipeline)
if err != nil {
return nil, err
}
defer expenseCursor.Close(ctx)
var expenseResult []bson.M
expenseCursor.All(ctx, &expenseResult)
totalExpense := 0.0
if len(expenseResult) > 0 {
if v, ok := expenseResult[0]["total"].(float64); ok {
totalExpense = v
}
}
// 收入总额(从清洗后数据统计)
incomePipeline := []bson.M{
{"$match": bson.M{"income_expense": "收入"}},
{"$group": bson.M{"_id": nil, "total": bson.M{"$sum": "$amount"}}},
}
incomeCursor, err := database.CleanedBillCollection.Aggregate(ctx, incomePipeline)
if err != nil {
return nil, err
}
defer incomeCursor.Close(ctx)
var incomeResult []bson.M
incomeCursor.All(ctx, &incomeResult)
totalIncome := 0.0
if len(incomeResult) > 0 {
if v, ok := incomeResult[0]["total"].(float64); ok {
totalIncome = v
}
}
return map[string]interface{}{
"raw_records": rawTotal,
"cleaned_records": cleanedTotal,
"total_expense": totalExpense,
"total_income": totalIncome,
}, nil
}

View File

@@ -131,4 +131,3 @@ func extractFromJSON(filePath string) []model.ReviewRecord {
return records
}