diff --git a/common/TokenClaims.ts b/common/TokenClaims.ts index eaf7479..81d7a31 100644 --- a/common/TokenClaims.ts +++ b/common/TokenClaims.ts @@ -19,4 +19,9 @@ export interface DeviceRegistrationToken extends BaseClaims { type: "device_reg"; } -export type TokenClaims = AccessToken | RefreshToken | DeviceRegistrationToken; +export interface DeviceToken extends BaseClaims { + type: "device"; + aud: string; +} + +export type TokenClaims = AccessToken | RefreshToken | DeviceRegistrationToken | DeviceToken; diff --git a/server/express/api/devices.ts b/server/express/api/devices.ts index 8e56827..fc7a370 100644 --- a/server/express/api/devices.ts +++ b/server/express/api/devices.ts @@ -4,10 +4,26 @@ import { serialize} from "serializr"; import ApiError from "@common/ApiError"; import { ErrorCode } from "@common/ErrorCode"; import * as schema from "@common/sprinklersRpc/schema"; -import { AccessToken } from "@common/TokenClaims"; -import { verifyAuthorization } from "@server/express/authentication"; +import { generateDeviceToken, verifyAuthorization } from "@server/express/authentication"; import { ServerState } from "@server/state"; +const DEVICE_ID_LEN = 20; + +function randomDeviceId(): string { + let deviceId = ""; + for (let i = 0; i < DEVICE_ID_LEN; i++) { + const j = Math.floor(Math.random() * 36); + let ch; // tslint:disable-next-line + if (j < 10) { // 0-9 + ch = String.fromCharCode(48 + j); + } else { // a-z + ch = String.fromCharCode(97 + (j - 10)); + } + deviceId += ch; + } + return deviceId; +} + export function devices(state: ServerState) { const router = PromiseRouter(); @@ -28,7 +44,23 @@ export function devices(state: ServerState) { router.post("/register", verifyAuthorization({ type: "device_reg", }), async (req, res) => { - // TODO: Implement device registration + const deviceId = randomDeviceId(); + const newDevice = state.database.sprinklersDevices.create({ + name: "Sprinklers Device", deviceId, + }); + await state.database.sprinklersDevices.save(newDevice); + const token = await generateDeviceToken(deviceId); + res.send({ + data: newDevice, token, + }); + }); + + router.post("/connect", verifyAuthorization({ + type: "device", + }), async (req, res) => { + res.send({ + url: state.mqttUrl, + }); }); return router; diff --git a/server/express/api/mosquitto.ts b/server/express/api/mosquitto.ts index 41f7113..e557406 100644 --- a/server/express/api/mosquitto.ts +++ b/server/express/api/mosquitto.ts @@ -16,6 +16,6 @@ export function mosquitto(state: ServerState) { router.post("/acl", async (req, res) => { res.status(200).send(); }); - + return router; -} \ No newline at end of file +} diff --git a/server/express/authentication.ts b/server/express/authentication.ts index b7869a9..5517f70 100644 --- a/server/express/authentication.ts +++ b/server/express/authentication.ts @@ -10,7 +10,7 @@ import { TokenGrantRequest, TokenGrantResponse, } from "@common/httpApi"; -import { AccessToken, DeviceRegistrationToken, RefreshToken, TokenClaims } from "@common/TokenClaims"; +import { AccessToken, DeviceRegistrationToken, DeviceToken, RefreshToken, TokenClaims } from "@common/TokenClaims"; import { User } from "../entities"; import { ServerState } from "../state"; @@ -69,7 +69,8 @@ export function verifyToken( } else { const claims: TokenClaims = decoded as any; if (type != null && claims.type !== type) { - reject(new ApiError(`Expected a "${type} token, received a "${claims.type}" token`)); + reject(new ApiError(`Expected a "${type}" token, received a "${claims.type}" token`, + ErrorCode.BadToken)); } resolve(claims as TClaims); } @@ -109,6 +110,15 @@ function generateDeviceRegistrationToken(secret: string): Promise { return signToken(device_reg_token_claims); } +export function generateDeviceToken(deviceId: string): Promise { + const device_token_claims: DeviceToken = { + iss: ISSUER, + type: "device", + aud: deviceId, + }; + return signToken(device_token_claims); +} + export function authentication(state: ServerState) { const router = Router(); @@ -158,7 +168,7 @@ export function authentication(state: ServerState) { } const [access_token, refresh_token] = await Promise.all( [await generateAccessToken(user, JWT_SECRET), - await generateRefreshToken(user, JWT_SECRET)]); + await generateRefreshToken(user, JWT_SECRET)]); const response: TokenGrantResponse = { access_token, refresh_token, }; @@ -188,7 +198,7 @@ export function verifyAuthorization(options?: Partial): const opts: VerifyAuthorizationOpts = { type: "access", ...options, - }; + }; return (req, res, next) => { const fun = async () => { const bearer = req.headers.authorization; @@ -201,12 +211,7 @@ export function verifyAuthorization(options?: Partial): } const token = matches[1]; - req.token = await verifyToken(token, "access"); - - if (req.token.type !== opts.type) { - throw new ApiError(`Invalid token type "${req.token.type}", must be "${opts.type}"`, - ErrorCode.BadToken); - } + req.token = await verifyToken(token, opts.type) as any; }; fun().then(() => next(null), (err) => next(err)); }; diff --git a/server/state.ts b/server/state.ts index 0604644..3a0dc6c 100644 --- a/server/state.ts +++ b/server/state.ts @@ -3,6 +3,7 @@ import * as mqtt from "@common/sprinklersRpc/mqtt"; import { Database } from "./Database"; export class ServerState { + mqttUrl: string; mqttClient: mqtt.MqttRpcClient; database: Database; @@ -11,6 +12,7 @@ export class ServerState { if (!mqttUrl) { throw new Error("Must specify a MQTT_URL to connect to"); } + this.mqttUrl = mqttUrl; this.mqttClient = new mqtt.MqttRpcClient(mqttUrl); this.database = new Database(); }