diff --git a/server/src/index.ts b/server/src/index.ts index 1912e40..fcd6cb8 100644 --- a/server/src/index.ts +++ b/server/src/index.ts @@ -2,6 +2,12 @@ import { Hono } from "hono"; import { cors } from "hono/cors"; import { auth } from "./auth.js"; import { env } from "./env.js"; +import { handleDeviceMessage, handleDeviceClose } from "./ws/device.js"; +import { + handleDashboardMessage, + handleDashboardClose, +} from "./ws/dashboard.js"; +import type { WebSocketData } from "./ws/sessions.js"; const app = new Hono(); @@ -25,18 +31,54 @@ app.on(["POST", "GET"], "/api/auth/*", (c) => { app.get("/health", (c) => c.json({ status: "ok" })); // Start server with WebSocket support -const server = Bun.serve({ +const server = Bun.serve({ port: env.PORT, - fetch: app.fetch, + fetch(req, server) { + const url = new URL(req.url); + + // WebSocket upgrade for device connections + if (url.pathname === "/ws/device") { + const upgraded = server.upgrade(req, { + data: { path: "/ws/device" as const, authenticated: false }, + }); + if (upgraded) return undefined; + return new Response("WebSocket upgrade failed", { status: 400 }); + } + + // WebSocket upgrade for dashboard connections + if (url.pathname === "/ws/dashboard") { + const upgraded = server.upgrade(req, { + data: { path: "/ws/dashboard" as const, authenticated: false }, + }); + if (upgraded) return undefined; + return new Response("WebSocket upgrade failed", { status: 400 }); + } + + // Non-WebSocket requests go to Hono + return app.fetch(req); + }, websocket: { open(ws) { - console.log("WebSocket connected"); + console.log(`WebSocket opened: ${ws.data.path}`); }, message(ws, message) { - // placeholder — Task 4 implements device/dashboard handlers + const raw = + typeof message === "string" + ? message + : new TextDecoder().decode(message); + + if (ws.data.path === "/ws/device") { + handleDeviceMessage(ws, raw); + } else if (ws.data.path === "/ws/dashboard") { + handleDashboardMessage(ws, raw); + } }, close(ws) { - console.log("WebSocket disconnected"); + if (ws.data.path === "/ws/device") { + handleDeviceClose(ws); + } else if (ws.data.path === "/ws/dashboard") { + handleDashboardClose(ws); + } }, }, }); diff --git a/server/src/ws/dashboard.ts b/server/src/ws/dashboard.ts new file mode 100644 index 0000000..b147603 --- /dev/null +++ b/server/src/ws/dashboard.ts @@ -0,0 +1,110 @@ +import type { ServerWebSocket } from "bun"; +import { auth } from "../auth.js"; +import { sessions, type WebSocketData } from "./sessions.js"; + +interface DashboardAuthMessage { + type: "auth"; + token: string; +} + +type DashboardIncomingMessage = DashboardAuthMessage; + +/** + * Handle an incoming message from a dashboard WebSocket. + */ +export async function handleDashboardMessage( + ws: ServerWebSocket, + raw: string +): Promise { + let msg: DashboardIncomingMessage; + try { + msg = JSON.parse(raw) as DashboardIncomingMessage; + } catch { + ws.send(JSON.stringify({ type: "error", message: "Invalid JSON" })); + return; + } + + // ── Authentication ───────────────────────────────────── + + if (msg.type === "auth") { + try { + // Verify the session token by constructing a request with the cookie header + const sessionResult = await auth.api.getSession({ + headers: new Headers({ + cookie: `better-auth.session_token=${msg.token}`, + }), + }); + + if (!sessionResult) { + ws.send( + JSON.stringify({ type: "auth_error", message: "Invalid session" }) + ); + return; + } + + const userId = sessionResult.user.id; + + // Mark connection as authenticated + ws.data.authenticated = true; + ws.data.userId = userId; + + // Register as dashboard subscriber + sessions.addDashboardSubscriber({ userId, ws }); + + // Send auth confirmation + ws.send(JSON.stringify({ type: "auth_ok" })); + + // Send current device list for this user + const devices = sessions.getDevicesForUser(userId); + for (const device of devices) { + const name = device.deviceInfo + ? `${device.deviceInfo.model} (Android ${device.deviceInfo.androidVersion})` + : device.deviceId; + + ws.send( + JSON.stringify({ + type: "device_online", + deviceId: device.deviceId, + name, + }) + ); + } + + console.log(`Dashboard subscriber authenticated: user ${userId}`); + } catch (err) { + ws.send( + JSON.stringify({ + type: "auth_error", + message: "Authentication failed", + }) + ); + console.error("Dashboard auth error:", err); + } + return; + } + + // ── All other messages require authentication ───────── + + if (!ws.data.authenticated) { + ws.send( + JSON.stringify({ type: "error", message: "Not authenticated" }) + ); + return; + } + + // Future: handle dashboard commands (e.g., send goal to device) + console.warn( + `Unknown message type from dashboard:`, + (msg as unknown as Record).type + ); +} + +/** + * Handle a dashboard WebSocket disconnection. + */ +export function handleDashboardClose( + ws: ServerWebSocket +): void { + sessions.removeDashboardSubscriber(ws); + console.log(`Dashboard subscriber disconnected: user ${ws.data.userId ?? "unknown"}`); +} diff --git a/server/src/ws/device.ts b/server/src/ws/device.ts new file mode 100644 index 0000000..ace6267 --- /dev/null +++ b/server/src/ws/device.ts @@ -0,0 +1,157 @@ +import type { ServerWebSocket } from "bun"; +import type { DeviceMessage } from "@droidclaw/shared"; +import { auth } from "../auth.js"; +import { sessions, type WebSocketData } from "./sessions.js"; + +/** + * Handle an incoming message from an Android device WebSocket. + */ +export async function handleDeviceMessage( + ws: ServerWebSocket, + raw: string +): Promise { + let msg: DeviceMessage; + try { + msg = JSON.parse(raw) as DeviceMessage; + } catch { + ws.send(JSON.stringify({ type: "error", message: "Invalid JSON" })); + return; + } + + // ── Authentication ───────────────────────────────────── + + if (msg.type === "auth") { + try { + const result = await auth.api.verifyApiKey({ + body: { key: msg.apiKey }, + }); + + if (!result.valid || !result.key) { + ws.send( + JSON.stringify({ + type: "auth_error", + message: result.error?.message ?? "Invalid API key", + }) + ); + return; + } + + const deviceId = crypto.randomUUID(); + const userId = result.key.userId; + + // Mark connection as authenticated + ws.data.authenticated = true; + ws.data.userId = userId; + ws.data.deviceId = deviceId; + + // Register device in session manager + sessions.addDevice({ + deviceId, + userId, + ws, + deviceInfo: msg.deviceInfo, + connectedAt: new Date(), + }); + + // Confirm auth to the device + ws.send(JSON.stringify({ type: "auth_ok", deviceId })); + + // Notify dashboard subscribers + const name = msg.deviceInfo + ? `${msg.deviceInfo.model} (Android ${msg.deviceInfo.androidVersion})` + : deviceId; + + sessions.notifyDashboard(userId, { + type: "device_online", + deviceId, + name, + }); + + console.log(`Device authenticated: ${deviceId} for user ${userId}`); + } catch (err) { + ws.send( + JSON.stringify({ + type: "auth_error", + message: "Authentication failed", + }) + ); + console.error("Device auth error:", err); + } + return; + } + + // ── All other messages require authentication ───────── + + if (!ws.data.authenticated) { + ws.send( + JSON.stringify({ type: "error", message: "Not authenticated" }) + ); + return; + } + + switch (msg.type) { + case "screen": { + // Device is reporting its screen state in response to a get_screen command + sessions.resolveRequest(msg.requestId, { + type: "screen", + elements: msg.elements, + screenshot: msg.screenshot, + packageName: msg.packageName, + }); + break; + } + + case "result": { + // Device is reporting the result of an action command + sessions.resolveRequest(msg.requestId, { + type: "result", + success: msg.success, + error: msg.error, + data: msg.data, + }); + break; + } + + case "goal": { + // Device is requesting a goal to be executed + // Task 6 wires up the agent loop here + console.log( + `Goal request from device ${ws.data.deviceId}: ${msg.text}` + ); + break; + } + + case "pong": { + // Heartbeat response — no-op + break; + } + + default: { + console.warn( + `Unknown message type from device ${ws.data.deviceId}:`, + (msg as Record).type + ); + } + } +} + +/** + * Handle a device WebSocket disconnection. + */ +export function handleDeviceClose( + ws: ServerWebSocket +): void { + const { deviceId, userId } = ws.data; + if (!deviceId) return; + + sessions.removeDevice(deviceId); + + if (userId) { + sessions.notifyDashboard(userId, { + type: "device_offline", + deviceId, + }); + } + + console.log(`Device disconnected: ${deviceId}`); +} diff --git a/server/src/ws/sessions.ts b/server/src/ws/sessions.ts new file mode 100644 index 0000000..65f2eaf --- /dev/null +++ b/server/src/ws/sessions.ts @@ -0,0 +1,162 @@ +import type { ServerWebSocket } from "bun"; +import type { DeviceInfo, DashboardMessage } from "@droidclaw/shared"; + +/** Data attached to each WebSocket connection by Bun.serve upgrade */ +export interface WebSocketData { + path: "/ws/device" | "/ws/dashboard"; + userId?: string; + deviceId?: string; + authenticated: boolean; +} + +/** A connected Android device */ +export interface ConnectedDevice { + deviceId: string; + userId: string; + ws: ServerWebSocket; + deviceInfo?: DeviceInfo; + connectedAt: Date; +} + +/** A dashboard client subscribed to real-time updates */ +export interface DashboardSubscriber { + userId: string; + ws: ServerWebSocket; +} + +/** A pending request waiting for a device response */ +export interface PendingRequest { + resolve: (data: unknown) => void; + reject: (error: Error) => void; + timer: ReturnType; +} + +const DEFAULT_COMMAND_TIMEOUT = 30_000; // 30 seconds + +class SessionManager { + private devices = new Map(); + private dashboardSubscribers = new Set(); + private pendingRequests = new Map(); + + // ── Device management ────────────────────────────────── + + addDevice(device: ConnectedDevice): void { + this.devices.set(device.deviceId, device); + } + + removeDevice(deviceId: string): void { + this.devices.delete(deviceId); + // Note: pending requests for this device will time out naturally + // since we can't map requestId → deviceId without extra bookkeeping. + } + + getDevice(deviceId: string): ConnectedDevice | undefined { + return this.devices.get(deviceId); + } + + getDevicesForUser(userId: string): ConnectedDevice[] { + const result: ConnectedDevice[] = []; + for (const device of this.devices.values()) { + if (device.userId === userId) { + result.push(device); + } + } + return result; + } + + getAllDevices(): ConnectedDevice[] { + return Array.from(this.devices.values()); + } + + // ── Dashboard subscriber management ─────────────────── + + addDashboardSubscriber(sub: DashboardSubscriber): void { + this.dashboardSubscribers.add(sub); + } + + removeDashboardSubscriber(ws: ServerWebSocket): void { + for (const sub of this.dashboardSubscribers) { + if (sub.ws === ws) { + this.dashboardSubscribers.delete(sub); + break; + } + } + } + + /** Send a JSON message to all dashboard subscribers for a given user */ + notifyDashboard(userId: string, message: DashboardMessage): void { + const payload = JSON.stringify(message); + for (const sub of this.dashboardSubscribers) { + if (sub.userId === userId) { + try { + sub.ws.send(payload); + } catch { + // subscriber disconnected; will be cleaned up on close + } + } + } + } + + // ── Request/response pattern for device commands ────── + + /** + * Send a command to a device and wait for its response. + * Returns a Promise that resolves when the device sends back + * a message with a matching requestId. + */ + sendCommand( + deviceId: string, + command: Record, + timeout = DEFAULT_COMMAND_TIMEOUT + ): Promise { + const device = this.devices.get(deviceId); + if (!device) { + return Promise.reject(new Error(`Device ${deviceId} not connected`)); + } + + const requestId = + command.requestId as string | undefined ?? + crypto.randomUUID(); + + const commandWithId = { ...command, requestId }; + + return new Promise((resolve, reject) => { + const timer = setTimeout(() => { + this.pendingRequests.delete(requestId); + reject(new Error(`Command timed out after ${timeout}ms`)); + }, timeout); + + this.pendingRequests.set(requestId, { resolve, reject, timer }); + + try { + device.ws.send(JSON.stringify(commandWithId)); + } catch (err) { + clearTimeout(timer); + this.pendingRequests.delete(requestId); + reject(new Error(`Failed to send command to device: ${err}`)); + } + }); + } + + /** Resolve a pending request when a device responds */ + resolveRequest(requestId: string, data: unknown): boolean { + const pending = this.pendingRequests.get(requestId); + if (!pending) return false; + + clearTimeout(pending.timer); + this.pendingRequests.delete(requestId); + pending.resolve(data); + return true; + } + + /** Get counts for monitoring */ + getStats() { + return { + devices: this.devices.size, + dashboardSubscribers: this.dashboardSubscribers.size, + pendingRequests: this.pendingRequests.size, + }; + } +} + +export const sessions = new SessionManager(); diff --git a/server/tsconfig.json b/server/tsconfig.json index ab48925..18fb4a4 100644 --- a/server/tsconfig.json +++ b/server/tsconfig.json @@ -7,7 +7,6 @@ "esModuleInterop": true, "skipLibCheck": true, "outDir": "dist", - "rootDir": "src", "types": ["bun"], "paths": { "@droidclaw/shared": ["../packages/shared/src"]