Initial commit
This commit is contained in:
293
golang/main.go
Normal file
293
golang/main.go
Normal file
@@ -0,0 +1,293 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
var db *sql.DB
|
||||
|
||||
// ==================== 数据库初始化 ====================
|
||||
|
||||
func initDB() error {
|
||||
// 获取可执行文件所在目录
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
execPath = "."
|
||||
}
|
||||
dbPath := filepath.Join(filepath.Dir(execPath), "demo.db")
|
||||
|
||||
// 如果数据库不存在,使用当前目录
|
||||
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
|
||||
dbPath = "demo.db"
|
||||
}
|
||||
|
||||
db, err = sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建示例表
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
email TEXT UNIQUE,
|
||||
age INTEGER,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS products (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
price REAL,
|
||||
stock INTEGER DEFAULT 0
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 插入示例数据(如果表为空)
|
||||
var count int
|
||||
db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
|
||||
if count == 0 {
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO users (name, email, age) VALUES
|
||||
('张三', 'zhangsan@example.com', 28),
|
||||
('李四', 'lisi@example.com', 32),
|
||||
('王五', 'wangwu@example.com', 25);
|
||||
|
||||
INSERT INTO products (name, price, stock) VALUES
|
||||
('iPhone 15', 6999.00, 100),
|
||||
('MacBook Pro', 14999.00, 50),
|
||||
('AirPods Pro', 1899.00, 200);
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== 工具输入结构 ====================
|
||||
|
||||
// ListTablesInput - 列出所有表
|
||||
type ListTablesInput struct{}
|
||||
|
||||
// QueryInput - 执行 SQL 查询
|
||||
type QueryInput struct {
|
||||
SQL string `json:"sql" mcp:"SQL query to execute (SELECT only for safety)"`
|
||||
}
|
||||
|
||||
// GetUserInput - 获取用户
|
||||
type GetUserInput struct {
|
||||
ID int `json:"id" mcp:"User ID to retrieve"`
|
||||
}
|
||||
|
||||
// AddUserInput - 添加用户
|
||||
type AddUserInput struct {
|
||||
Name string `json:"name" mcp:"User name"`
|
||||
Email string `json:"email" mcp:"User email address"`
|
||||
Age int `json:"age" mcp:"User age"`
|
||||
}
|
||||
|
||||
// DeleteUserInput - 删除用户
|
||||
type DeleteUserInput struct {
|
||||
ID int `json:"id" mcp:"User ID to delete"`
|
||||
}
|
||||
|
||||
// ==================== 辅助函数 ====================
|
||||
|
||||
func textResult(text string) (*mcp.CallToolResult, any, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: text},
|
||||
},
|
||||
}, nil, nil
|
||||
}
|
||||
|
||||
func errorResult(text string) (*mcp.CallToolResult, any, error) {
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: text},
|
||||
},
|
||||
IsError: true,
|
||||
}, nil, nil
|
||||
}
|
||||
|
||||
// ==================== 工具实现 ====================
|
||||
|
||||
// ListTables - 列出数据库中的所有表
|
||||
func ListTables(ctx context.Context, req *mcp.CallToolRequest, input ListTablesInput) (*mcp.CallToolResult, any, error) {
|
||||
rows, err := db.Query("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")
|
||||
if err != nil {
|
||||
return errorResult(fmt.Sprintf("Failed to list tables: %v", err))
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []string
|
||||
for rows.Next() {
|
||||
var name string
|
||||
rows.Scan(&name)
|
||||
tables = append(tables, name)
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(tables, "", " ")
|
||||
return textResult(fmt.Sprintf("Tables in database:\n%s", string(result)))
|
||||
}
|
||||
|
||||
// QueryDB - 执行 SQL 查询
|
||||
func QueryDB(ctx context.Context, req *mcp.CallToolRequest, input QueryInput) (*mcp.CallToolResult, any, error) {
|
||||
// 安全检查:只允许 SELECT 查询
|
||||
sqlUpper := strings.ToUpper(strings.TrimSpace(input.SQL))
|
||||
if !strings.HasPrefix(sqlUpper, "SELECT") {
|
||||
return errorResult("Only SELECT queries are allowed for safety")
|
||||
}
|
||||
|
||||
rows, err := db.Query(input.SQL)
|
||||
if err != nil {
|
||||
return errorResult(fmt.Sprintf("Query failed: %v", err))
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// 获取列名
|
||||
columns, _ := rows.Columns()
|
||||
|
||||
// 读取所有行
|
||||
var results []map[string]interface{}
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
rows.Scan(valuePtrs...)
|
||||
|
||||
row := make(map[string]interface{})
|
||||
for i, col := range columns {
|
||||
row[col] = values[i]
|
||||
}
|
||||
results = append(results, row)
|
||||
}
|
||||
|
||||
jsonResult, _ := json.MarshalIndent(results, "", " ")
|
||||
return textResult(fmt.Sprintf("Query results (%d rows):\n%s", len(results), string(jsonResult)))
|
||||
}
|
||||
|
||||
// GetUser - 获取单个用户
|
||||
func GetUser(ctx context.Context, req *mcp.CallToolRequest, input GetUserInput) (*mcp.CallToolResult, any, error) {
|
||||
var id int
|
||||
var name, email string
|
||||
var age int
|
||||
var createdAt string
|
||||
|
||||
err := db.QueryRow("SELECT id, name, email, age, created_at FROM users WHERE id = ?", input.ID).
|
||||
Scan(&id, &name, &email, &age, &createdAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return errorResult(fmt.Sprintf("User with ID %d not found", input.ID))
|
||||
}
|
||||
if err != nil {
|
||||
return errorResult(fmt.Sprintf("Query failed: %v", err))
|
||||
}
|
||||
|
||||
user := map[string]interface{}{
|
||||
"id": id,
|
||||
"name": name,
|
||||
"email": email,
|
||||
"age": age,
|
||||
"created_at": createdAt,
|
||||
}
|
||||
|
||||
jsonResult, _ := json.MarshalIndent(user, "", " ")
|
||||
return textResult(string(jsonResult))
|
||||
}
|
||||
|
||||
// AddUser - 添加用户
|
||||
func AddUser(ctx context.Context, req *mcp.CallToolRequest, input AddUserInput) (*mcp.CallToolResult, any, error) {
|
||||
result, err := db.Exec("INSERT INTO users (name, email, age) VALUES (?, ?, ?)",
|
||||
input.Name, input.Email, input.Age)
|
||||
if err != nil {
|
||||
return errorResult(fmt.Sprintf("Failed to add user: %v", err))
|
||||
}
|
||||
|
||||
id, _ := result.LastInsertId()
|
||||
return textResult(fmt.Sprintf("User added successfully with ID: %d", id))
|
||||
}
|
||||
|
||||
// DeleteUser - 删除用户
|
||||
func DeleteUser(ctx context.Context, req *mcp.CallToolRequest, input DeleteUserInput) (*mcp.CallToolResult, any, error) {
|
||||
result, err := db.Exec("DELETE FROM users WHERE id = ?", input.ID)
|
||||
if err != nil {
|
||||
return errorResult(fmt.Sprintf("Failed to delete user: %v", err))
|
||||
}
|
||||
|
||||
affected, _ := result.RowsAffected()
|
||||
if affected == 0 {
|
||||
return errorResult(fmt.Sprintf("User with ID %d not found", input.ID))
|
||||
}
|
||||
|
||||
return textResult(fmt.Sprintf("User with ID %d deleted successfully", input.ID))
|
||||
}
|
||||
|
||||
// ==================== 主函数 ====================
|
||||
|
||||
func main() {
|
||||
// 初始化数据库
|
||||
if err := initDB(); err != nil {
|
||||
log.Fatalf("Failed to initialize database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// 创建 MCP Server
|
||||
server := mcp.NewServer(
|
||||
&mcp.Implementation{
|
||||
Name: "mcp-demo-go",
|
||||
Version: "1.0.0",
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
// 注册工具
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "list_tables",
|
||||
Description: "List all tables in the SQLite database",
|
||||
}, ListTables)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "query_db",
|
||||
Description: "Execute a SELECT SQL query on the database",
|
||||
}, QueryDB)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "get_user",
|
||||
Description: "Get a user by ID from the users table",
|
||||
}, GetUser)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "add_user",
|
||||
Description: "Add a new user to the database",
|
||||
}, AddUser)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "delete_user",
|
||||
Description: "Delete a user from the database by ID",
|
||||
}, DeleteUser)
|
||||
|
||||
// 运行服务器
|
||||
log.Println("MCP Demo Go Server (SQLite) is running...")
|
||||
if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil {
|
||||
log.Fatalf("Server error: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user