revert(server): use direct DB queries for all auth validation

Reverts middleware and dashboard WS to direct DB session lookups.
Replaces auth.api.verifyApiKey in device WS with direct DB query
using SHA-256 hash matching, removing dependency on BETTER_AUTH_SECRET
for auth validation.
This commit is contained in:
Sanju Sivalingam
2026-02-18 11:46:48 +05:30
parent a865c1e2f0
commit 68ca812267
3 changed files with 86 additions and 21 deletions

View File

@@ -1,5 +1,8 @@
import type { Context, Next } from "hono"; import type { Context, Next } from "hono";
import { auth } from "../auth.js"; import { db } from "../db.js";
import { session as sessionTable, user as userTable } from "../schema.js";
import { eq } from "drizzle-orm";
import { getCookie } from "hono/cookie";
/** Hono Env type for routes protected by sessionMiddleware */ /** Hono Env type for routes protected by sessionMiddleware */
export type AuthEnv = { export type AuthEnv = {
@@ -10,13 +13,43 @@ export type AuthEnv = {
}; };
export async function sessionMiddleware(c: Context, next: Next) { export async function sessionMiddleware(c: Context, next: Next) {
const session = await auth.api.getSession({ headers: c.req.raw.headers }); // Extract session token from cookie (same approach as dashboard WS auth)
const rawCookie = getCookie(c, "better-auth.session_token");
if (!session) { if (!rawCookie) {
return c.json({ error: "unauthorized" }, 401); return c.json({ error: "unauthorized" }, 401);
} }
c.set("user", session.user); // Token may have a signature appended after a dot — use only the token part
c.set("session", session.session); const token = rawCookie.split(".")[0];
// Direct DB lookup (proven to work, unlike auth.api.getSession)
const rows = await db
.select({
sessionId: sessionTable.id,
userId: sessionTable.userId,
})
.from(sessionTable)
.where(eq(sessionTable.token, token))
.limit(1);
if (rows.length === 0) {
return c.json({ error: "unauthorized" }, 401);
}
const { sessionId, userId } = rows[0];
// Fetch user info
const users = await db
.select({ id: userTable.id, name: userTable.name, email: userTable.email })
.from(userTable)
.where(eq(userTable.id, userId))
.limit(1);
if (users.length === 0) {
return c.json({ error: "unauthorized" }, 401);
}
c.set("user", users[0]);
c.set("session", { id: sessionId, userId });
await next(); await next();
} }

View File

@@ -1,5 +1,7 @@
import type { ServerWebSocket } from "bun"; import type { ServerWebSocket } from "bun";
import { auth } from "../auth.js"; import { db } from "../db.js";
import { session as sessionTable } from "../schema.js";
import { eq } from "drizzle-orm";
import { sessions, type WebSocketData } from "./sessions.js"; import { sessions, type WebSocketData } from "./sessions.js";
interface DashboardAuthMessage { interface DashboardAuthMessage {
@@ -34,17 +36,19 @@ export async function handleDashboardMessage(
return; return;
} }
// Validate session via better-auth // Look up session directly in DB
const session = await auth.api.getSession({ const rows = await db
headers: new Headers({ cookie: `better-auth.session_token=${token}` }), .select({ userId: sessionTable.userId })
}); .from(sessionTable)
.where(eq(sessionTable.token, token))
.limit(1);
if (!session) { if (rows.length === 0) {
ws.send(JSON.stringify({ type: "auth_error", message: "Invalid session" })); ws.send(JSON.stringify({ type: "auth_error", message: "Invalid session" }));
return; return;
} }
const userId = session.user.id; const userId = rows[0].userId;
// Mark connection as authenticated // Mark connection as authenticated
ws.data.authenticated = true; ws.data.authenticated = true;

View File

@@ -1,13 +1,26 @@
import type { ServerWebSocket } from "bun"; import type { ServerWebSocket } from "bun";
import type { DeviceMessage } from "@droidclaw/shared"; import type { DeviceMessage } from "@droidclaw/shared";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
import { auth } from "../auth.js";
import { db } from "../db.js"; import { db } from "../db.js";
import { llmConfig, device } from "../schema.js"; import { apikey, llmConfig, device } from "../schema.js";
import { sessions, type WebSocketData } from "./sessions.js"; import { sessions, type WebSocketData } from "./sessions.js";
import { runPipeline } from "../agent/pipeline.js"; import { runPipeline } from "../agent/pipeline.js";
import type { LLMConfig } from "../agent/llm.js"; import type { LLMConfig } from "../agent/llm.js";
/**
* Hash an API key the same way better-auth does:
* SHA-256 → base64url (no padding).
*/
async function hashApiKey(key: string): Promise<string> {
const data = new TextEncoder().encode(key);
const hash = await crypto.subtle.digest("SHA-256", data);
// base64url encode without padding
return btoa(String.fromCharCode(...new Uint8Array(hash)))
.replace(/\+/g, "-")
.replace(/\//g, "_")
.replace(/=+$/, "");
}
/** Track running agent sessions to prevent duplicates per device */ /** Track running agent sessions to prevent duplicates per device */
const activeSessions = new Map<string, string>(); const activeSessions = new Map<string, string>();
@@ -77,22 +90,37 @@ export async function handleDeviceMessage(
if (msg.type === "auth") { if (msg.type === "auth") {
try { try {
const result = await auth.api.verifyApiKey({ // Hash the incoming key and look it up directly in the DB
body: { key: msg.apiKey }, const hashedKey = await hashApiKey(msg.apiKey);
}); const rows = await db
.select({ id: apikey.id, userId: apikey.userId, enabled: apikey.enabled, expiresAt: apikey.expiresAt })
.from(apikey)
.where(eq(apikey.key, hashedKey))
.limit(1);
if (!result.valid || !result.key) { if (rows.length === 0 || !rows[0].enabled) {
ws.send( ws.send(
JSON.stringify({ JSON.stringify({
type: "auth_error", type: "auth_error",
message: result.error?.message ?? "Invalid API key", message: "Invalid API key",
})
);
return;
}
// Check expiration
if (rows[0].expiresAt && rows[0].expiresAt < new Date()) {
ws.send(
JSON.stringify({
type: "auth_error",
message: "API key expired",
}) })
); );
return; return;
} }
const deviceId = crypto.randomUUID(); const deviceId = crypto.randomUUID();
const userId = result.key.userId; const userId = rows[0].userId;
// Build device name from device info // Build device name from device info
const name = msg.deviceInfo const name = msg.deviceInfo