Skip to content

feat: enhance MCP server support with authentication and environment variables #1644

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: Add MCP (Model Context Protocol) host backend implementation
- Add MCP server models and schemas for storing server configurations
- Implement MCP connection manager supporting both stdio and HTTP+SSE transports
- Create REST API endpoints for managing MCP servers (CRUD operations)
- Add WebSocket endpoint for real-time MCP communication
- Implement tool calling, resource fetching, and prompt retrieval
- Add database migration for mcp_servers table with JSON fields
- Support both local (stdio) and remote (HTTP+SSE) MCP servers
- Add aiohttp dependency for HTTP client functionality

This backend implementation provides a foundation for browser-based MCP hosting,
allowing users to connect to and interact with MCP servers similar to Claude Desktop.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
  • Loading branch information
User and claude committed May 27, 2025
commit e6d0c3d333b24c8f99da2091b4ea2125daeaa5b8
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Add MCP server table

Revision ID: 2025_05_27_0402
Revises: 2025_05_26_1343
Create Date: 2025-05-27 04:02:50.892296

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
import sqlmodel

# revision identifiers, used by Alembic.
revision: str = '2025_05_27_0402'
down_revision: Union[str, None] = '2025_05_26_1343'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# Create enum type for transport
op.execute("CREATE TYPE mcptransporttype AS ENUM ('stdio', 'http_sse')")

# Create mcp_servers table
op.create_table('mcp_servers',
sa.Column('id', postgresql.UUID(as_uuid=True), server_default=sa.text('gen_random_uuid()'), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('description', sa.String(), nullable=True),
sa.Column('transport', sa.Enum('stdio', 'http_sse', name='mcptransporttype'), nullable=False),
sa.Column('command', sa.String(), nullable=True),
sa.Column('args', postgresql.JSON(astext_type=sa.Text()), nullable=True),
sa.Column('url', sa.String(), nullable=True),
sa.Column('headers', postgresql.JSON(astext_type=sa.Text()), nullable=True),
sa.Column('is_enabled', sa.Boolean(), nullable=False, server_default='true'),
sa.Column('is_remote', sa.Boolean(), nullable=False, server_default='false'),
sa.Column('capabilities', postgresql.JSON(astext_type=sa.Text()), nullable=True),
sa.Column('tools', postgresql.JSON(astext_type=sa.Text()), nullable=True),
sa.Column('resources', postgresql.JSON(astext_type=sa.Text()), nullable=True),
sa.Column('prompts', postgresql.JSON(astext_type=sa.Text()), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_mcp_servers_id'), 'mcp_servers', ['id'], unique=False)
op.create_index(op.f('ix_mcp_servers_name'), 'mcp_servers', ['name'], unique=True)


def downgrade() -> None:
# Drop table and indexes
op.drop_index(op.f('ix_mcp_servers_name'), table_name='mcp_servers')
op.drop_index(op.f('ix_mcp_servers_id'), table_name='mcp_servers')
op.drop_table('mcp_servers')

# Drop enum type
op.execute("DROP TYPE mcptransporttype")
3 changes: 2 additions & 1 deletion backend/app/api/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from fastapi import APIRouter

from app.api.routes import items, login, private, users, utils
from app.api.routes import items, login, private, users, utils, mcp
from app.api.routes.auth.router import router as auth_router
from app.core.config import settings

@@ -14,6 +14,7 @@
api_router.include_router(users.router, prefix="/users", tags=["users"])
api_router.include_router(utils.router, prefix="/utils", tags=["utils"])
api_router.include_router(items.router, prefix="/items", tags=["items"])
api_router.include_router(mcp.router, prefix="/mcp", tags=["mcp"])

# Include private routes in local environment
if settings.ENVIRONMENT == "local":
268 changes: 268 additions & 0 deletions backend/app/api/routes/mcp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
"""API routes for MCP server management."""
from typing import List, Any
from fastapi import APIRouter, HTTPException, Depends
from sqlmodel import Session, select
from app.api.deps import get_current_active_user, get_current_active_superuser, get_db
from app.models import User, MCPServer, MCPServerCreate, MCPServerUpdate, MCPServerPublic
from app.services.mcp_manager import mcp_manager
import logging

logger = logging.getLogger(__name__)

router = APIRouter()


@router.get("/servers", response_model=List[MCPServerPublic])
async def list_mcp_servers(
*,
session: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user),
) -> List[MCPServerPublic]:
"""List all MCP servers."""
servers = session.exec(select(MCPServer)).all()

# Update runtime data from manager
result = []
for server in servers:
server_dict = server.model_dump()
if server.name in mcp_manager.connections:
connection = mcp_manager.connections[server.name]
server_dict["status"] = connection.status
server_dict["error_message"] = connection.error_message
else:
server_dict["status"] = MCPServerStatus.DISCONNECTED
server_dict["error_message"] = None
result.append(MCPServerPublic(**server_dict))

return result


@router.post("/servers", response_model=MCPServerPublic)
async def create_mcp_server(
*,
session: Session = Depends(get_db),
current_user: User = Depends(get_current_active_superuser),
server_in: MCPServerCreate,
) -> MCPServerPublic:
"""Create a new MCP server configuration."""
# Check if server with same name exists
existing = session.exec(select(MCPServer).where(MCPServer.name == server_in.name)).first()
if existing:
raise HTTPException(status_code=400, detail="Server with this name already exists")

server = MCPServer.model_validate(server_in)
session.add(server)
session.commit()
session.refresh(server)

# Auto-connect if enabled
if server.is_enabled:
try:
await mcp_manager.connect_server(server)
except Exception as e:
logger.error(f"Failed to connect to server {server.name}: {e}")

return server


@router.get("/servers/{server_id}", response_model=MCPServerPublic)
async def get_mcp_server(
*,
session: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user),
server_id: str,
) -> MCPServerPublic:
"""Get MCP server by ID."""
server = session.get(MCPServer, server_id)
if not server:
raise HTTPException(status_code=404, detail="Server not found")

# Update runtime data from manager
server_dict = server.model_dump()
if server.name in mcp_manager.connections:
connection = mcp_manager.connections[server.name]
server_dict["status"] = connection.status
server_dict["error_message"] = connection.error_message
else:
server_dict["status"] = MCPServerStatus.DISCONNECTED
server_dict["error_message"] = None

return MCPServerPublic(**server_dict)


@router.patch("/servers/{server_id}", response_model=MCPServerPublic)
async def update_mcp_server(
*,
session: Session = Depends(get_db),
current_user: User = Depends(get_current_active_superuser),
server_id: str,
server_in: MCPServerUpdate,
) -> MCPServerPublic:
"""Update MCP server configuration."""
server = session.get(MCPServer, server_id)
if not server:
raise HTTPException(status_code=404, detail="Server not found")

# Disconnect if connected
if server.name in mcp_manager.connections:
await mcp_manager.disconnect_server(server.name)

# Update server
update_data = server_in.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(server, key, value)

session.add(server)
session.commit()
session.refresh(server)

# Reconnect if enabled
if server.is_enabled:
try:
await mcp_manager.connect_server(server)
except Exception as e:
logger.error(f"Failed to connect to server {server.name}: {e}")

return server


@router.delete("/servers/{server_id}")
async def delete_mcp_server(
*,
session: Session = Depends(get_db),
current_user: User = Depends(get_current_active_superuser),
server_id: str,
) -> dict:
"""Delete MCP server."""
server = session.get(MCPServer, server_id)
if not server:
raise HTTPException(status_code=404, detail="Server not found")

# Disconnect if connected
if server.name in mcp_manager.connections:
await mcp_manager.disconnect_server(server.name)

session.delete(server)
session.commit()

return {"message": "Server deleted successfully"}


@router.post("/servers/{server_id}/connect")
async def connect_mcp_server(
*,
session: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user),
server_id: str,
) -> dict:
"""Connect to an MCP server."""
server = session.get(MCPServer, server_id)
if not server:
raise HTTPException(status_code=404, detail="Server not found")

if not server.is_enabled:
raise HTTPException(status_code=400, detail="Server is disabled")

try:
success = await mcp_manager.connect_server(server)
if success:
# Update server with discovered capabilities
session.add(server)
session.commit()
return {"message": "Connected successfully", "status": mcp_manager.connections[server.name].status.value}
else:
connection = mcp_manager.connections.get(server.name)
error_msg = connection.error_message if connection else "Failed to connect"
raise HTTPException(status_code=500, detail=error_msg)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@router.post("/servers/{server_id}/disconnect")
async def disconnect_mcp_server(
*,
session: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user),
server_id: str,
) -> dict:
"""Disconnect from an MCP server."""
server = session.get(MCPServer, server_id)
if not server:
raise HTTPException(status_code=404, detail="Server not found")

if server.name not in mcp_manager.connections:
raise HTTPException(status_code=400, detail="Server not connected")

await mcp_manager.disconnect_server(server.name)
return {"message": "Disconnected successfully"}


@router.post("/servers/{server_id}/tools/{tool_name}/call")
async def call_mcp_tool(
*,
session: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user),
server_id: str,
tool_name: str,
arguments: dict[str, Any],
) -> Any:
"""Call a tool on an MCP server."""
server = session.get(MCPServer, server_id)
if not server:
raise HTTPException(status_code=404, detail="Server not found")

if server.name not in mcp_manager.connections:
raise HTTPException(status_code=400, detail="Server not connected")

try:
result = await mcp_manager.call_tool(server.name, tool_name, arguments)
return result
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@router.get("/servers/{server_id}/resources/{uri:path}")
async def get_mcp_resource(
*,
session: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user),
server_id: str,
uri: str,
) -> Any:
"""Get a resource from an MCP server."""
server = session.get(MCPServer, server_id)
if not server:
raise HTTPException(status_code=404, detail="Server not found")

if server.name not in mcp_manager.connections:
raise HTTPException(status_code=400, detail="Server not connected")

try:
result = await mcp_manager.get_resource(server.name, uri)
return result
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@router.post("/servers/{server_id}/prompts/{prompt_name}")
async def get_mcp_prompt(
*,
session: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user),
server_id: str,
prompt_name: str,
arguments: dict[str, Any],
) -> Any:
"""Get a prompt from an MCP server."""
server = session.get(MCPServer, server_id)
if not server:
raise HTTPException(status_code=404, detail="Server not found")

if server.name not in mcp_manager.connections:
raise HTTPException(status_code=400, detail="Server not connected")

try:
result = await mcp_manager.get_prompt(server.name, prompt_name, arguments)
return result
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
224 changes: 224 additions & 0 deletions backend/app/api/routes/mcp_websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
"""WebSocket endpoint for real-time MCP communication."""
import json
import logging
from typing import Dict, Any
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Query
from sqlmodel import Session
from app.api.deps import get_db
from app.models import User
from app.services.mcp_manager import mcp_manager
from app.core.security import decode_token
from app.core.config import settings

logger = logging.getLogger(__name__)

router = APIRouter()


class ConnectionManager:
"""Manages WebSocket connections."""

def __init__(self):
self.active_connections: Dict[str, WebSocket] = {}

async def connect(self, websocket: WebSocket, user_id: str):
await websocket.accept()
self.active_connections[user_id] = websocket

def disconnect(self, user_id: str):
self.active_connections.pop(user_id, None)

async def send_message(self, user_id: str, message: dict):
if user_id in self.active_connections:
websocket = self.active_connections[user_id]
await websocket.send_json(message)

async def broadcast(self, message: dict):
for websocket in self.active_connections.values():
await websocket.send_json(message)


manager = ConnectionManager()


async def get_current_user_from_token(token: str, session: Session) -> User:
"""Get current user from JWT token."""
try:
payload = decode_token(token)
user_id = payload.get("sub")
if not user_id:
raise ValueError("Invalid token")

user = session.get(User, user_id)
if not user or not user.is_active:
raise ValueError("User not found or inactive")

return user
except Exception as e:
logger.error(f"Token validation failed: {e}")
raise


@router.websocket("/ws")
async def websocket_endpoint(
websocket: WebSocket,
token: str = Query(...),
session: Session = Depends(get_db),
):
"""WebSocket endpoint for real-time MCP communication."""
user = None
try:
# Authenticate user
user = await get_current_user_from_token(token, session)
await manager.connect(websocket, str(user.id))

# Send initial connection message
await websocket.send_json({
"type": "connection",
"status": "connected",
"user_id": str(user.id)
})

# Listen for messages
while True:
try:
data = await websocket.receive_json()
await handle_websocket_message(websocket, user, data, session)
except WebSocketDisconnect:
break
except json.JSONDecodeError:
await websocket.send_json({
"type": "error",
"error": "Invalid JSON"
})
except Exception as e:
logger.error(f"WebSocket error: {e}")
await websocket.send_json({
"type": "error",
"error": str(e)
})

except Exception as e:
logger.error(f"WebSocket connection error: {e}")
await websocket.close(code=1008, reason="Authentication failed")
finally:
if user:
manager.disconnect(str(user.id))


async def handle_websocket_message(
websocket: WebSocket,
user: User,
data: Dict[str, Any],
session: Session
):
"""Handle incoming WebSocket messages."""
msg_type = data.get("type")

if msg_type == "ping":
await websocket.send_json({"type": "pong"})

elif msg_type == "tool_call":
# Call tool on MCP server
server_name = data.get("server")
tool_name = data.get("tool")
arguments = data.get("arguments", {})

try:
result = await mcp_manager.call_tool(server_name, tool_name, arguments)
await websocket.send_json({
"type": "tool_result",
"server": server_name,
"tool": tool_name,
"result": result
})
except Exception as e:
await websocket.send_json({
"type": "tool_error",
"server": server_name,
"tool": tool_name,
"error": str(e)
})

elif msg_type == "resource_get":
# Get resource from MCP server
server_name = data.get("server")
uri = data.get("uri")

try:
result = await mcp_manager.get_resource(server_name, uri)
await websocket.send_json({
"type": "resource_result",
"server": server_name,
"uri": uri,
"result": result
})
except Exception as e:
await websocket.send_json({
"type": "resource_error",
"server": server_name,
"uri": uri,
"error": str(e)
})

elif msg_type == "prompt_get":
# Get prompt from MCP server
server_name = data.get("server")
prompt_name = data.get("prompt")
arguments = data.get("arguments", {})

try:
result = await mcp_manager.get_prompt(server_name, prompt_name, arguments)
await websocket.send_json({
"type": "prompt_result",
"server": server_name,
"prompt": prompt_name,
"result": result
})
except Exception as e:
await websocket.send_json({
"type": "prompt_error",
"server": server_name,
"prompt": prompt_name,
"error": str(e)
})

elif msg_type == "server_status":
# Get status of all servers
from app.models import MCPServer
from sqlmodel import select

servers = session.exec(select(MCPServer)).all()
status_list = []

for server in servers:
status = {
"id": str(server.id),
"name": server.name,
"status": "disconnected",
"capabilities": None,
"tools": None,
"resources": None,
"prompts": None
}

if server.name in mcp_manager.connections:
connection = mcp_manager.connections[server.name]
status["status"] = connection.server.status.value
status["capabilities"] = connection.server.capabilities
status["tools"] = connection.server.tools
status["resources"] = connection.server.resources
status["prompts"] = connection.server.prompts

status_list.append(status)

await websocket.send_json({
"type": "server_status_result",
"servers": status_list
})

else:
await websocket.send_json({
"type": "error",
"error": f"Unknown message type: {msg_type}"
})
2 changes: 2 additions & 0 deletions backend/app/main.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
from starlette.middleware.cors import CORSMiddleware

from app.api.main import api_router
from app.api.routes.mcp_websocket import router as mcp_ws_router
from app.core.config import settings


@@ -31,3 +32,4 @@ def custom_generate_unique_id(route: APIRoute) -> str:
)

app.include_router(api_router, prefix=settings.API_V1_STR)
app.include_router(mcp_ws_router, prefix="/mcp", tags=["mcp-websocket"])
17 changes: 17 additions & 0 deletions backend/app/models/__init__.py
Original file line number Diff line number Diff line change
@@ -36,6 +36,16 @@
Message,
)

from app.models.mcp_server import (
MCPServer,
MCPServerBase,
MCPServerCreate,
MCPServerUpdate,
MCPServerPublic,
MCPTransportType,
MCPServerStatus,
)

# This ensures that SQLModel knows about all models for migrations
__all__ = [
'BaseDBModel',
@@ -68,4 +78,11 @@
'ItemPublic',
'ItemsPublic',
'Message',
'MCPServer',
'MCPServerBase',
'MCPServerCreate',
'MCPServerUpdate',
'MCPServerPublic',
'MCPTransportType',
'MCPServerStatus',
]
78 changes: 78 additions & 0 deletions backend/app/models/mcp_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""MCP Server model for storing server configurations."""
import enum
from typing import Optional
from uuid import UUID
from datetime import datetime
from sqlmodel import Field, SQLModel
from sqlalchemy import Column
from sqlalchemy.dialects.postgresql import JSON
from app.models.base import BaseDBModel


class MCPTransportType(str, enum.Enum):
"""MCP transport protocol types."""
STDIO = "stdio"
HTTP_SSE = "http_sse"


class MCPServerStatus(str, enum.Enum):
"""MCP server connection status."""
DISCONNECTED = "disconnected"
CONNECTING = "connecting"
CONNECTED = "connected"
ERROR = "error"


class MCPServerBase(SQLModel):
"""Base MCP server configuration."""
name: str = Field(index=True, unique=True)
description: Optional[str] = None
transport: MCPTransportType = Field(default=MCPTransportType.STDIO)
command: Optional[str] = Field(default=None, description="Command to execute for stdio transport")
args: Optional[list[str]] = Field(default=None, sa_column=Column(JSON))
url: Optional[str] = Field(default=None, description="URL for HTTP+SSE transport")
headers: Optional[dict[str, str]] = Field(default=None, sa_column=Column(JSON))
is_enabled: bool = Field(default=True)
is_remote: bool = Field(default=False)


class MCPServer(BaseDBModel, MCPServerBase, table=True):
"""MCP server configuration stored in database."""
__tablename__ = "mcp_servers"

# Server capabilities (discovered after connection)
capabilities: Optional[dict] = Field(default=None, sa_column=Column(JSON))
tools: Optional[list[dict]] = Field(default=None, sa_column=Column(JSON))
resources: Optional[list[dict]] = Field(default=None, sa_column=Column(JSON))
prompts: Optional[list[dict]] = Field(default=None, sa_column=Column(JSON))


class MCPServerCreate(MCPServerBase):
"""Schema for creating MCP server."""
pass


class MCPServerUpdate(SQLModel):
"""Schema for updating MCP server."""
name: Optional[str] = None
description: Optional[str] = None
transport: Optional[MCPTransportType] = None
command: Optional[str] = None
args: Optional[list[str]] = None
url: Optional[str] = None
headers: Optional[dict[str, str]] = None
is_enabled: Optional[bool] = None
is_remote: Optional[bool] = None


class MCPServerPublic(MCPServerBase):
"""Public schema for MCP server."""
id: UUID
status: MCPServerStatus = MCPServerStatus.DISCONNECTED
error_message: Optional[str] = None
capabilities: Optional[dict] = None
tools: Optional[list[dict]] = None
resources: Optional[list[dict]] = None
prompts: Optional[list[dict]] = None
created_at: datetime
updated_at: datetime
Empty file.
337 changes: 337 additions & 0 deletions backend/app/services/mcp_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,337 @@
"""MCP Server Manager for handling MCP server connections."""
import asyncio
import json
import logging
from typing import Dict, Optional, Any
from datetime import datetime
import subprocess
import aiohttp
from app.models.mcp_server import MCPServer, MCPServerStatus, MCPTransportType
from app.core.config import settings

logger = logging.getLogger(__name__)


class MCPConnection:
"""Represents a connection to an MCP server."""

def __init__(self, server: MCPServer):
self.server = server
self.process: Optional[subprocess.Popen] = None
self.session: Optional[aiohttp.ClientSession] = None
self.reader: Optional[asyncio.StreamReader] = None
self.writer: Optional[asyncio.StreamWriter] = None
self._read_task: Optional[asyncio.Task] = None
self._message_id = 0
self._pending_requests: Dict[int, asyncio.Future] = {}
# Runtime status
self.status: MCPServerStatus = MCPServerStatus.DISCONNECTED
self.error_message: Optional[str] = None

async def connect(self) -> bool:
"""Connect to the MCP server."""
try:
if self.server.transport == MCPTransportType.STDIO:
return await self._connect_stdio()
elif self.server.transport == MCPTransportType.HTTP_SSE:
return await self._connect_http_sse()
else:
raise ValueError(f"Unsupported transport: {self.server.transport}")
except Exception as e:
logger.error(f"Failed to connect to MCP server {self.server.name}: {e}")
self.status = MCPServerStatus.ERROR
self.error_message = str(e)
return False

async def _connect_stdio(self) -> bool:
"""Connect to MCP server via stdio."""
if not self.server.command:
raise ValueError("Command is required for stdio transport")

args = [self.server.command]
if self.server.args:
args.extend(self.server.args)

# Start the process
self.process = await asyncio.create_subprocess_exec(
*args,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)

self.reader = self.process.stdout
self.writer = self.process.stdin

# Start reading messages
self._read_task = asyncio.create_task(self._read_messages())

# Send initialize request
response = await self._send_request("initialize", {
"protocolVersion": "0.1.0",
"capabilities": {
"roots": True,
"tools": True,
"prompts": True,
"resources": True
},
"clientInfo": {
"name": "copilot-mcp-host",
"version": "0.1.0"
}
})

if response and "capabilities" in response:
self.server.capabilities = response["capabilities"]
self.status = MCPServerStatus.CONNECTED

# Fetch available tools, resources, and prompts
await self._fetch_server_features()
return True

return False

async def _connect_http_sse(self) -> bool:
"""Connect to MCP server via HTTP+SSE."""
if not self.server.url:
raise ValueError("URL is required for HTTP+SSE transport")

self.session = aiohttp.ClientSession(headers=self.server.headers or {})

# Send initialize request
response = await self._send_http_request("initialize", {
"protocolVersion": "0.1.0",
"capabilities": {
"roots": True,
"tools": True,
"prompts": True,
"resources": True
},
"clientInfo": {
"name": "copilot-mcp-host",
"version": "0.1.0"
}
})

if response and "capabilities" in response:
self.server.capabilities = response["capabilities"]
self.status = MCPServerStatus.CONNECTED

# Fetch available tools, resources, and prompts
await self._fetch_server_features()
return True

return False

async def _fetch_server_features(self):
"""Fetch available tools, resources, and prompts from the server."""
# Fetch tools
if self.server.capabilities.get("tools"):
tools_response = await self._send_request("tools/list", {})
if tools_response and "tools" in tools_response:
self.server.tools = tools_response["tools"]

# Fetch resources
if self.server.capabilities.get("resources"):
resources_response = await self._send_request("resources/list", {})
if resources_response and "resources" in resources_response:
self.server.resources = resources_response["resources"]

# Fetch prompts
if self.server.capabilities.get("prompts"):
prompts_response = await self._send_request("prompts/list", {})
if prompts_response and "prompts" in prompts_response:
self.server.prompts = prompts_response["prompts"]

async def _send_request(self, method: str, params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Send a JSON-RPC request to the server."""
if self.server.transport == MCPTransportType.STDIO:
return await self._send_stdio_request(method, params)
elif self.server.transport == MCPTransportType.HTTP_SSE:
return await self._send_http_request(method, params)

async def _send_stdio_request(self, method: str, params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Send a request via stdio."""
if not self.writer:
raise RuntimeError("Not connected")

self._message_id += 1
message = {
"jsonrpc": "2.0",
"id": self._message_id,
"method": method,
"params": params
}

# Create a future for this request
future = asyncio.Future()
self._pending_requests[self._message_id] = future

# Send the message
message_str = json.dumps(message) + "\n"
self.writer.write(message_str.encode())
await self.writer.drain()

# Wait for response
try:
response = await asyncio.wait_for(future, timeout=30.0)
return response
except asyncio.TimeoutError:
self._pending_requests.pop(self._message_id, None)
raise

async def _send_http_request(self, method: str, params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Send a request via HTTP."""
if not self.session:
raise RuntimeError("Not connected")

self._message_id += 1
message = {
"jsonrpc": "2.0",
"id": self._message_id,
"method": method,
"params": params
}

async with self.session.post(self.server.url, json=message) as response:
if response.status == 200:
return await response.json()
else:
raise RuntimeError(f"HTTP error: {response.status}")

async def _read_messages(self):
"""Read messages from stdio."""
buffer = ""
while self.reader and not self.reader.at_eof():
try:
data = await self.reader.read(1024)
if not data:
break

buffer += data.decode()

# Process complete messages
while "\n" in buffer:
line, buffer = buffer.split("\n", 1)
if line.strip():
try:
message = json.loads(line)
await self._handle_message(message)
except json.JSONDecodeError:
logger.error(f"Invalid JSON: {line}")

except Exception as e:
logger.error(f"Error reading messages: {e}")
break

async def _handle_message(self, message: Dict[str, Any]):
"""Handle an incoming message."""
if "id" in message and message["id"] in self._pending_requests:
# This is a response to our request
future = self._pending_requests.pop(message["id"])
if "result" in message:
future.set_result(message["result"])
elif "error" in message:
future.set_exception(RuntimeError(message["error"]))

async def disconnect(self):
"""Disconnect from the MCP server."""
try:
if self._read_task:
self._read_task.cancel()

if self.writer:
self.writer.close()
await self.writer.wait_closed()

if self.process:
self.process.terminate()
await self.process.wait()

if self.session:
await self.session.close()

self.status = MCPServerStatus.DISCONNECTED

except Exception as e:
logger.error(f"Error disconnecting from MCP server {self.server.name}: {e}")

async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
"""Call a tool on the MCP server."""
response = await self._send_request("tools/call", {
"name": tool_name,
"arguments": arguments
})
return response

async def get_resource(self, uri: str) -> Any:
"""Get a resource from the MCP server."""
response = await self._send_request("resources/read", {
"uri": uri
})
return response

async def get_prompt(self, name: str, arguments: Dict[str, Any]) -> Any:
"""Get a prompt from the MCP server."""
response = await self._send_request("prompts/get", {
"name": name,
"arguments": arguments
})
return response


class MCPServerManager:
"""Manager for MCP server connections."""

def __init__(self):
self.connections: Dict[str, MCPConnection] = {}

async def connect_server(self, server: MCPServer) -> bool:
"""Connect to an MCP server."""
if server.name in self.connections:
await self.disconnect_server(server.name)

# Set connecting status
connection = MCPConnection(server)

if await connection.connect():
self.connections[server.name] = connection
return True
else:
return False

async def disconnect_server(self, server_name: str):
"""Disconnect from an MCP server."""
if server_name in self.connections:
connection = self.connections.pop(server_name)
await connection.disconnect()

async def call_tool(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> Any:
"""Call a tool on a specific MCP server."""
if server_name not in self.connections:
raise RuntimeError(f"Server {server_name} not connected")

return await self.connections[server_name].call_tool(tool_name, arguments)

async def get_resource(self, server_name: str, uri: str) -> Any:
"""Get a resource from a specific MCP server."""
if server_name not in self.connections:
raise RuntimeError(f"Server {server_name} not connected")

return await self.connections[server_name].get_resource(uri)

async def get_prompt(self, server_name: str, name: str, arguments: Dict[str, Any]) -> Any:
"""Get a prompt from a specific MCP server."""
if server_name not in self.connections:
raise RuntimeError(f"Server {server_name} not connected")

return await self.connections[server_name].get_prompt(name, arguments)

async def disconnect_all(self):
"""Disconnect from all MCP servers."""
for server_name in list(self.connections.keys()):
await self.disconnect_server(server_name)


# Global instance
mcp_manager = MCPServerManager()
1 change: 1 addition & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@ dependencies = [
"pydantic-settings<3.0.0,>=2.2.1",
"sentry-sdk[fastapi]<2.0.0,>=1.40.6",
"pyjwt<3.0.0,>=2.8.0",
"aiohttp<4.0.0,>=3.9.0",
]

[tool.uv]