294 lines
7.3 KiB
Go
294 lines
7.3 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
|
)
|
|
|
|
// db is the process-wide database handle. initDB opens it, every tool handler
// queries through it, and main closes it on shutdown.
var db *sql.DB
|
|
|
|
// ==================== 数据库初始化 ====================
|
|
|
|
func initDB() error {
|
|
// 获取可执行文件所在目录
|
|
execPath, err := os.Executable()
|
|
if err != nil {
|
|
execPath = "."
|
|
}
|
|
dbPath := filepath.Join(filepath.Dir(execPath), "demo.db")
|
|
|
|
// 如果数据库不存在,使用当前目录
|
|
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
|
|
dbPath = "demo.db"
|
|
}
|
|
|
|
db, err = sql.Open("sqlite3", dbPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 创建示例表
|
|
_, err = db.Exec(`
|
|
CREATE TABLE IF NOT EXISTS users (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
name TEXT NOT NULL,
|
|
email TEXT UNIQUE,
|
|
age INTEGER,
|
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS products (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
name TEXT NOT NULL,
|
|
price REAL,
|
|
stock INTEGER DEFAULT 0
|
|
);
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 插入示例数据(如果表为空)
|
|
var count int
|
|
db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
|
|
if count == 0 {
|
|
_, err = db.Exec(`
|
|
INSERT INTO users (name, email, age) VALUES
|
|
('张三', 'zhangsan@example.com', 28),
|
|
('李四', 'lisi@example.com', 32),
|
|
('王五', 'wangwu@example.com', 25);
|
|
|
|
INSERT INTO products (name, price, stock) VALUES
|
|
('iPhone 15', 6999.00, 100),
|
|
('MacBook Pro', 14999.00, 50),
|
|
('AirPods Pro', 1899.00, 200);
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ==================== 工具输入结构 ====================
|
|
|
|
// ListTablesInput is the argument type of the list_tables tool. That tool
// accepts no arguments, so the struct deliberately has no fields.
type ListTablesInput struct {
}
|
|
|
|
// QueryInput is the argument type of the query_db tool.
//
// The jsonschema tag supplies the property description in the input schema
// that the MCP SDK infers for the tool. A tag named `mcp` is not read for
// this purpose, so the description would otherwise be dropped.
type QueryInput struct {
	// SQL is the statement to run; QueryDB accepts only SELECT statements.
	SQL string `json:"sql" jsonschema:"SQL query to execute (SELECT only for safety)"`
}
|
|
|
|
// GetUserInput is the argument type of the get_user tool.
//
// The jsonschema tag supplies the property description in the input schema
// that the MCP SDK infers for the tool. An `mcp` tag is not read for that.
type GetUserInput struct {
	// ID is the users.id value to look up.
	ID int `json:"id" jsonschema:"User ID to retrieve"`
}
|
|
|
|
// AddUserInput is the argument type of the add_user tool.
//
// The jsonschema tags supply the property descriptions in the input schema
// that the MCP SDK infers for the tool. `mcp` tags are not read for that.
type AddUserInput struct {
	Name  string `json:"name" jsonschema:"User name"`
	Email string `json:"email" jsonschema:"User email address"`
	Age   int    `json:"age" jsonschema:"User age"`
}
|
|
|
|
// DeleteUserInput is the argument type of the delete_user tool.
//
// The jsonschema tag supplies the property description in the input schema
// that the MCP SDK infers for the tool. An `mcp` tag is not read for that.
type DeleteUserInput struct {
	// ID is the users.id value of the row to delete.
	ID int `json:"id" jsonschema:"User ID to delete"`
}
|
|
|
|
// ==================== 辅助函数 ====================
|
|
|
|
func textResult(text string) (*mcp.CallToolResult, any, error) {
|
|
return &mcp.CallToolResult{
|
|
Content: []mcp.Content{
|
|
&mcp.TextContent{Text: text},
|
|
},
|
|
}, nil, nil
|
|
}
|
|
|
|
func errorResult(text string) (*mcp.CallToolResult, any, error) {
|
|
return &mcp.CallToolResult{
|
|
Content: []mcp.Content{
|
|
&mcp.TextContent{Text: text},
|
|
},
|
|
IsError: true,
|
|
}, nil, nil
|
|
}
|
|
|
|
// ==================== 工具实现 ====================
|
|
|
|
// ListTables - 列出数据库中的所有表
|
|
func ListTables(ctx context.Context, req *mcp.CallToolRequest, input ListTablesInput) (*mcp.CallToolResult, any, error) {
|
|
rows, err := db.Query("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")
|
|
if err != nil {
|
|
return errorResult(fmt.Sprintf("Failed to list tables: %v", err))
|
|
}
|
|
defer rows.Close()
|
|
|
|
var tables []string
|
|
for rows.Next() {
|
|
var name string
|
|
rows.Scan(&name)
|
|
tables = append(tables, name)
|
|
}
|
|
|
|
result, _ := json.MarshalIndent(tables, "", " ")
|
|
return textResult(fmt.Sprintf("Tables in database:\n%s", string(result)))
|
|
}
|
|
|
|
// QueryDB - 执行 SQL 查询
|
|
func QueryDB(ctx context.Context, req *mcp.CallToolRequest, input QueryInput) (*mcp.CallToolResult, any, error) {
|
|
// 安全检查:只允许 SELECT 查询
|
|
sqlUpper := strings.ToUpper(strings.TrimSpace(input.SQL))
|
|
if !strings.HasPrefix(sqlUpper, "SELECT") {
|
|
return errorResult("Only SELECT queries are allowed for safety")
|
|
}
|
|
|
|
rows, err := db.Query(input.SQL)
|
|
if err != nil {
|
|
return errorResult(fmt.Sprintf("Query failed: %v", err))
|
|
}
|
|
defer rows.Close()
|
|
|
|
// 获取列名
|
|
columns, _ := rows.Columns()
|
|
|
|
// 读取所有行
|
|
var results []map[string]interface{}
|
|
for rows.Next() {
|
|
values := make([]interface{}, len(columns))
|
|
valuePtrs := make([]interface{}, len(columns))
|
|
for i := range values {
|
|
valuePtrs[i] = &values[i]
|
|
}
|
|
|
|
rows.Scan(valuePtrs...)
|
|
|
|
row := make(map[string]interface{})
|
|
for i, col := range columns {
|
|
row[col] = values[i]
|
|
}
|
|
results = append(results, row)
|
|
}
|
|
|
|
jsonResult, _ := json.MarshalIndent(results, "", " ")
|
|
return textResult(fmt.Sprintf("Query results (%d rows):\n%s", len(results), string(jsonResult)))
|
|
}
|
|
|
|
// GetUser - 获取单个用户
|
|
func GetUser(ctx context.Context, req *mcp.CallToolRequest, input GetUserInput) (*mcp.CallToolResult, any, error) {
|
|
var id int
|
|
var name, email string
|
|
var age int
|
|
var createdAt string
|
|
|
|
err := db.QueryRow("SELECT id, name, email, age, created_at FROM users WHERE id = ?", input.ID).
|
|
Scan(&id, &name, &email, &age, &createdAt)
|
|
if err == sql.ErrNoRows {
|
|
return errorResult(fmt.Sprintf("User with ID %d not found", input.ID))
|
|
}
|
|
if err != nil {
|
|
return errorResult(fmt.Sprintf("Query failed: %v", err))
|
|
}
|
|
|
|
user := map[string]interface{}{
|
|
"id": id,
|
|
"name": name,
|
|
"email": email,
|
|
"age": age,
|
|
"created_at": createdAt,
|
|
}
|
|
|
|
jsonResult, _ := json.MarshalIndent(user, "", " ")
|
|
return textResult(string(jsonResult))
|
|
}
|
|
|
|
// AddUser - 添加用户
|
|
func AddUser(ctx context.Context, req *mcp.CallToolRequest, input AddUserInput) (*mcp.CallToolResult, any, error) {
|
|
result, err := db.Exec("INSERT INTO users (name, email, age) VALUES (?, ?, ?)",
|
|
input.Name, input.Email, input.Age)
|
|
if err != nil {
|
|
return errorResult(fmt.Sprintf("Failed to add user: %v", err))
|
|
}
|
|
|
|
id, _ := result.LastInsertId()
|
|
return textResult(fmt.Sprintf("User added successfully with ID: %d", id))
|
|
}
|
|
|
|
// DeleteUser - 删除用户
|
|
func DeleteUser(ctx context.Context, req *mcp.CallToolRequest, input DeleteUserInput) (*mcp.CallToolResult, any, error) {
|
|
result, err := db.Exec("DELETE FROM users WHERE id = ?", input.ID)
|
|
if err != nil {
|
|
return errorResult(fmt.Sprintf("Failed to delete user: %v", err))
|
|
}
|
|
|
|
affected, _ := result.RowsAffected()
|
|
if affected == 0 {
|
|
return errorResult(fmt.Sprintf("User with ID %d not found", input.ID))
|
|
}
|
|
|
|
return textResult(fmt.Sprintf("User with ID %d deleted successfully", input.ID))
|
|
}
|
|
|
|
// ==================== 主函数 ====================
|
|
|
|
func main() {
|
|
// 初始化数据库
|
|
if err := initDB(); err != nil {
|
|
log.Fatalf("Failed to initialize database: %v", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
// 创建 MCP Server
|
|
server := mcp.NewServer(
|
|
&mcp.Implementation{
|
|
Name: "mcp-demo-go",
|
|
Version: "1.0.0",
|
|
},
|
|
nil,
|
|
)
|
|
|
|
// 注册工具
|
|
mcp.AddTool(server, &mcp.Tool{
|
|
Name: "list_tables",
|
|
Description: "List all tables in the SQLite database",
|
|
}, ListTables)
|
|
|
|
mcp.AddTool(server, &mcp.Tool{
|
|
Name: "query_db",
|
|
Description: "Execute a SELECT SQL query on the database",
|
|
}, QueryDB)
|
|
|
|
mcp.AddTool(server, &mcp.Tool{
|
|
Name: "get_user",
|
|
Description: "Get a user by ID from the users table",
|
|
}, GetUser)
|
|
|
|
mcp.AddTool(server, &mcp.Tool{
|
|
Name: "add_user",
|
|
Description: "Add a new user to the database",
|
|
}, AddUser)
|
|
|
|
mcp.AddTool(server, &mcp.Tool{
|
|
Name: "delete_user",
|
|
Description: "Delete a user from the database by ID",
|
|
}, DeleteUser)
|
|
|
|
// 运行服务器
|
|
log.Println("MCP Demo Go Server (SQLite) is running...")
|
|
if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil {
|
|
log.Fatalf("Server error: %v", err)
|
|
}
|
|
}
|