From 929622603650d7d2254809a8f9e13062099f21c9 Mon Sep 17 00:00:00 2001 From: Alex Mikhalev Date: Wed, 29 Aug 2018 01:31:10 -0600 Subject: [PATCH] Properly subscribe to and unsubscribe from devices --- Dockerfile | 1 - client/components/DeviceView.tsx | 61 +++-- client/pages/LogoutPage.tsx | 2 +- client/pages/ProgramPage.tsx | 92 ++++--- client/sprinklersRpc/WSSprinklersDevice.ts | 26 +- client/sprinklersRpc/WebSocketRpcClient.ts | 50 ++-- common/TypedEventEmitter.ts | 16 ++ common/jsonRpc/index.ts | 15 +- common/sprinklersRpc/RpcError.ts | 21 ++ common/sprinklersRpc/SprinklersDevice.ts | 33 ++- common/sprinklersRpc/SprinklersRPC.ts | 30 ++- common/sprinklersRpc/index.ts | 2 +- common/sprinklersRpc/mqtt/index.ts | 83 ++++--- common/sprinklersRpc/websocketData.ts | 22 +- common/utils.ts | 8 + server/express/api/devices.ts | 3 +- server/express/authentication.ts | 22 +- server/index.ts | 2 +- server/sprinklersRpc/WebSocketApi.ts | 26 ++ server/sprinklersRpc/WebSocketConnection.ts | 262 ++++++++++++++++++++ server/sprinklersRpc/websocketServer.ts | 247 ------------------ 21 files changed, 616 insertions(+), 408 deletions(-) create mode 100644 common/sprinklersRpc/RpcError.ts create mode 100644 server/sprinklersRpc/WebSocketApi.ts create mode 100644 server/sprinklersRpc/WebSocketConnection.ts delete mode 100644 server/sprinklersRpc/websocketServer.ts diff --git a/Dockerfile b/Dockerfile index 36517ab..a56733d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -24,5 +24,4 @@ COPY --from=builder /app/dist ./dist COPY --from=builder /app/public ./public EXPOSE 8080 -EXPOSE 8081 ENTRYPOINT [ "node", "." ] diff --git a/client/components/DeviceView.tsx b/client/components/DeviceView.tsx index 376fdd8..969bf91 100644 --- a/client/components/DeviceView.tsx +++ b/client/components/DeviceView.tsx @@ -52,10 +52,22 @@ interface DeviceViewProps { inList?: boolean; } -class DeviceView extends React.Component> { - renderBody(iDevice: ISprinklersDevice, device: SprinklersDevice) { +class DeviceView extends React.Component { + deviceInfo: ISprinklersDevice | null = null; + device: SprinklersDevice | null = null; + + componentWillUnmount() { + if (this.device) { + this.device.release(); + } + } + + renderBody() { const { inList, appState: { uiStore, routerStore } } = this.props; - const { connectionState, sectionRunner, sections } = device; + if (!this.deviceInfo || !this.device) { + return null; + } + const { connectionState, sectionRunner, sections } = this.device; if (!connectionState.isAvailable || inList) { return null; } @@ -69,30 +81,51 @@ class DeviceView extends React.Component - + - + ); } + updateDevice() { + const { userStore, sprinklersRpc } = this.props.appState; + const id = this.props.deviceId; + // tslint:disable-next-line:prefer-conditional-expression + if (this.deviceInfo == null || this.deviceInfo.id !== id) { + this.deviceInfo = userStore.findDevice(id); + } + if (!this.deviceInfo || !this.deviceInfo.deviceId) { + if (this.device) { + this.device.release(); + this.device = null; + } + } else { + if (this.device == null || this.device.id !== this.deviceInfo.deviceId) { + if (this.device) { + this.device.release(); + } + this.device = sprinklersRpc.acquireDevice(this.deviceInfo.deviceId); + } + } + } + render() { - const { deviceId, inList, appState: { sprinklersRpc, userStore } } = this.props; - const iDevice = userStore.findDevice(deviceId); + this.updateDevice(); + const { inList } = this.props; let itemContent: React.ReactNode; - if (!iDevice || !iDevice.deviceId) { + if (!this.deviceInfo || !this.device) { // TODO: better and link back to devices list itemContent = You do not have access to this device; } else { - const device = sprinklersRpc.getDevice(iDevice.deviceId); - const { connectionState } = device; + const { connectionState } = this.device; let header: React.ReactNode; if (inList) { // tslint:disable-line:prefer-conditional-expression - header = Device {iDevice.name}; + header = Device {this.deviceInfo.name}; } else { - header = Device {iDevice.name}; + header = Device {this.deviceInfo.name}; } itemContent = ( @@ -105,7 +138,7 @@ class DeviceView extends React.Component Raspberry Pi Grinklers Device - {this.renderBody(iDevice, device)} + {this.renderBody()} ); @@ -114,4 +147,4 @@ class DeviceView extends React.Component ); diff --git a/client/pages/ProgramPage.tsx b/client/pages/ProgramPage.tsx index 561846a..759beae 100644 --- a/client/pages/ProgramPage.tsx +++ b/client/pages/ProgramPage.tsx @@ -1,4 +1,4 @@ -import { assign, merge } from "lodash"; +import { assign } from "lodash"; import { observer } from "mobx-react"; import * as qs from "query-string"; import * as React from "react"; @@ -23,11 +23,60 @@ class ProgramPage extends React.Component { return qs.parse(this.props.location.search).editing != null; } - iDevice!: ISprinklersDevice; - device!: SprinklersDevice; - program!: Program; + deviceInfo: ISprinklersDevice | null = null; + device: SprinklersDevice | null = null; + program: Program | null = null; programView: Program | null = null; + componentWillUnmount() { + if (this.device) { + this.device.release(); + } + } + + updateProgram() { + const { userStore, sprinklersRpc } = this.props.appState; + const devId = Number(this.props.match.params.deviceId); + const programId = Number(this.props.match.params.programId); + // tslint:disable-next-line:prefer-conditional-expression + if (this.deviceInfo == null || this.deviceInfo.id !== devId) { + this.deviceInfo = userStore.findDevice(devId); + } + if (!this.deviceInfo || !this.deviceInfo.deviceId) { + if (this.device) { + this.device.release(); + this.device = null; + } + return; + } else { + if (this.device == null || this.device.id !== this.deviceInfo.deviceId) { + if (this.device) { + this.device.release(); + } + this.device = sprinklersRpc.acquireDevice(this.deviceInfo.deviceId); + } + } + if (!this.program || this.program.id !== programId) { + if (this.device.programs.length > programId && programId >= 0) { + this.program = this.device.programs[programId]; + } else { + return; + } + } + if (this.isEditing) { + if (this.programView == null && this.program) { + // this.programView = createViewModel(this.program); + // this.programView = observable(toJS(this.program)); + this.programView = this.program.clone(); + } + } else { + if (this.programView != null) { + // this.programView.reset(); + this.programView = null; + } + } + } + renderName(program: Program) { const { name } = program; if (this.isEditing) { @@ -98,39 +147,12 @@ class ProgramPage extends React.Component { } render() { - const { deviceId: did, programId: pid } = this.props.match.params; - const { userStore, sprinklersRpc } = this.props.appState; - const deviceId = Number(did); - const programId = Number(pid); - // tslint:disable-next-line:prefer-conditional-expression - if (!this.iDevice || this.iDevice.id !== deviceId) { - this.iDevice = userStore.findDevice(deviceId)!; - } - if (this.iDevice && this.iDevice.deviceId && (!this.device || this.device.id !== this.iDevice.deviceId)) { - this.device = sprinklersRpc.getDevice(this.iDevice.deviceId); - } - // tslint:disable-next-line:prefer-conditional-expression - if (!this.program || this.program.id !== programId) { - if (this.device.programs.length > programId && programId >= 0) { - this.program = this.device.programs[programId]; - } else { - return null; - } - } - if (this.isEditing) { - if (this.programView == null && this.program) { - // this.programView = createViewModel(this.program); - // this.programView = observable(toJS(this.program)); - this.programView = this.program.clone(); - } - } else { - if (this.programView != null) { - // this.programView.reset(); - this.programView = null; - } - } + this.updateProgram(); const program = this.programView || this.program; + if (!this.device || !program) { + return null; + } const editing = this.isEditing; const { running, enabled, schedule, sequence } = program; diff --git a/client/sprinklersRpc/WSSprinklersDevice.ts b/client/sprinklersRpc/WSSprinklersDevice.ts index 43c819c..97b032d 100644 --- a/client/sprinklersRpc/WSSprinklersDevice.ts +++ b/client/sprinklersRpc/WSSprinklersDevice.ts @@ -9,17 +9,15 @@ import { log, WebSocketRpcClient } from "./WebSocketRpcClient"; // tslint:disable:member-ordering export class WSSprinklersDevice extends s.SprinklersDevice { readonly api: WebSocketRpcClient; - private _id: string; + constructor(api: WebSocketRpcClient, id: string) { - super(); + super(api, id); this.api = api; - this._id = id; + autorun(this.updateConnectionState); this.waitSubscribe(); } - get id() { - return this._id; - } + private updateConnectionState = () => { const { clientToServer, serverToBroker } = this.api.connectionState; runInAction("updateConnectionState", () => { @@ -34,7 +32,7 @@ export class WSSprinklersDevice extends s.SprinklersDevice { try { await this.api.makeRequest("deviceSubscribe", subscribeRequest); runInAction("deviceSubscribeSuccess", () => { - this.connectionState.brokerToDevice = true; + this.connectionState.hasPermission = true; }); } catch (err) { runInAction("deviceSubscribeError", () => { @@ -48,6 +46,20 @@ export class WSSprinklersDevice extends s.SprinklersDevice { } } + async unsubscribe() { + const unsubscribeRequest: ws.IDeviceSubscribeRequest = { + deviceId: this.id, + }; + try { + await this.api.makeRequest("deviceUnsubscribe", unsubscribeRequest); + runInAction("deviceUnsubscribeSuccess", () => { + this.connectionState.brokerToDevice = false; + }); + } catch (err) { + log.error({ err }, "error unsubscribing from device"); + } + } + makeRequest(request: deviceRequests.Request): Promise { return this.api.makeDeviceCall(this.id, request); } diff --git a/client/sprinklersRpc/WebSocketRpcClient.ts b/client/sprinklersRpc/WebSocketRpcClient.ts index ff0bb0f..9e98840 100644 --- a/client/sprinklersRpc/WebSocketRpcClient.ts +++ b/client/sprinklersRpc/WebSocketRpcClient.ts @@ -1,4 +1,4 @@ -import { action, observable, runInAction, when } from "mobx"; +import { action, computed, observable, runInAction, when } from "mobx"; import { update } from "serializr"; import { TokenStore } from "@client/state/TokenStore"; @@ -6,12 +6,12 @@ import { ErrorCode } from "@common/ErrorCode"; import { IUser } from "@common/httpApi"; import * as rpc from "@common/jsonRpc"; import logger from "@common/logger"; +import * as s from "@common/sprinklersRpc"; import * as deviceRequests from "@common/sprinklersRpc/deviceRequests"; -import * as s from "@common/sprinklersRpc/index"; -import * as schema from "@common/sprinklersRpc/schema/index"; +import * as schema from "@common/sprinklersRpc/schema/"; import { seralizeRequest } from "@common/sprinklersRpc/schema/requests"; import * as ws from "@common/sprinklersRpc/websocketData"; -import { DefaultEvents, TypedEventEmitter } from "@common/TypedEventEmitter"; +import { DefaultEvents, TypedEventEmitter, typedEventEmitter } from "@common/TypedEventEmitter"; import { WSSprinklersDevice } from "./WSSprinklersDevice"; export const log = logger.child({ source: "websocket" }); @@ -27,15 +27,22 @@ const DEFAULT_URL = `${websocketProtocol}//${location.hostname}:${websocketPort} export interface WebSocketRpcClientEvents extends DefaultEvents { newUserData(userData: IUser): void; - rpcError(error: ws.RpcError): void; - tokenError(error: ws.RpcError): void; + rpcError(error: s.RpcError): void; + tokenError(error: s.RpcError): void; } -export class WebSocketRpcClient extends TypedEventEmitter implements s.SprinklersRPC { +// tslint:disable:member-ordering +export interface WebSocketRpcClient extends TypedEventEmitter { +} + +@typedEventEmitter +export class WebSocketRpcClient extends s.SprinklersRPC { + @computed get connected(): boolean { return this.connectionState.isServerConnected || false; } + readonly webSocketUrl: string; devices: Map = new Map(); @@ -51,7 +58,6 @@ export class WebSocketRpcClient extends TypedEventEmitter { this.connectionState.serverToBroker = null; @@ -68,7 +74,7 @@ export class WebSocketRpcClient extends TypedEventEmitter { + this.on("rpcError", (err: s.RpcError) => { if (err.code === ErrorCode.BadToken) { this.emit("tokenError", err); } @@ -90,7 +96,9 @@ export class WebSocketRpcClient extends TypedEventEmitter { + log.debug({ id }, "released device"); + this.devices.delete(id); + }); } async authenticate(accessToken: string): Promise { @@ -120,7 +134,7 @@ export class WebSocketRpcClient extends TypedEventEmitter { + runInAction("authenticateError", () => { this.authenticated = false; }); } @@ -134,13 +148,13 @@ export class WebSocketRpcClient extends TypedEventEmitter { delete this.responseCallbacks[id]; - reject(new ws.RpcError("the request timed out", ErrorCode.Timeout)); + reject(new s.RpcError("the request timed out", ErrorCode.Timeout)); }, TIMEOUT_MS); this.sendRequest(id, method, params); }) .catch((err) => { - if (err instanceof ws.RpcError) { + if (err instanceof s.RpcError) { this.emit("rpcError", err); } throw err; diff --git a/common/TypedEventEmitter.ts b/common/TypedEventEmitter.ts index a67f957..9156fd9 100644 --- a/common/TypedEventEmitter.ts +++ b/common/TypedEventEmitter.ts @@ -45,4 +45,20 @@ const TypedEventEmitter = EventEmitter as { }; type TypedEventEmitter = ITypedEventEmitter; +type Constructable = new (...args: any[]) => any; + +export function typedEventEmitter(Base: TBase): + TBase & TypedEventEmitter { + const NewClass = class extends Base { + constructor(...args: any[]) { + super(...args); + EventEmitter.call(this); + } + }; + Object.getOwnPropertyNames(EventEmitter.prototype).forEach((name) => { + NewClass.prototype[name] = (EventEmitter.prototype as any)[name]; + }); + return NewClass as any; +} + export { TypedEventEmitter }; diff --git a/common/jsonRpc/index.ts b/common/jsonRpc/index.ts index 694c369..85eee17 100644 --- a/common/jsonRpc/index.ts +++ b/common/jsonRpc/index.ts @@ -126,30 +126,35 @@ export async function handleRequest( handlers: RequestHandlers, message: Request, + thisParam?: any, ): Promise> { const handler = handlers[message.method]; if (!handler) { throw new Error("No handler for request method " + message.method); } - return handler(message.params); + return handler.call(thisParam, message.params); } export function handleResponse( handlers: ResponseHandlers, - message: Response) { + message: Response, + thisParam?: any, +) { const handler = handlers[message.id]; if (!handler) { return; } - return handler(message); + return handler.call(thisParam, message); } export function handleNotification( handlers: NotificationHandlers, - message: Notification) { + message: Notification, + thisParam?: any, +) { const handler = handlers[message.method]; if (!handler) { throw new Error("No handler for notification method " + message.method); } - return handler(message.data); + return handler.call(thisParam, message.data); } diff --git a/common/sprinklersRpc/RpcError.ts b/common/sprinklersRpc/RpcError.ts new file mode 100644 index 0000000..e2a15c4 --- /dev/null +++ b/common/sprinklersRpc/RpcError.ts @@ -0,0 +1,21 @@ +import { ErrorCode } from "@common/ErrorCode"; +import { IError } from "./websocketData"; + +export class RpcError extends Error implements IError { + name = "RpcError"; + code: number; + data: any; + + constructor(message: string, code: number = ErrorCode.BadRequest, data: any = {}) { + super(message); + this.code = code; + if (data instanceof Error) { + this.data = data.toString(); + } + this.data = data; + } + + toJSON(): IError { + return { code: this.code, message: this.message, data: this.data }; + } +} diff --git a/common/sprinklersRpc/SprinklersDevice.ts b/common/sprinklersRpc/SprinklersDevice.ts index 666a0dc..c2436f9 100644 --- a/common/sprinklersRpc/SprinklersDevice.ts +++ b/common/sprinklersRpc/SprinklersDevice.ts @@ -4,8 +4,12 @@ import * as req from "./deviceRequests"; import { Program } from "./Program"; import { Section } from "./Section"; import { SectionRunner } from "./SectionRunner"; +import { SprinklersRPC } from "./SprinklersRPC"; export abstract class SprinklersDevice { + readonly rpc: SprinklersRPC; + readonly id: string; + @observable connectionState: ConnectionState = new ConnectionState(); @observable sections: Section[] = []; @observable programs: Program[] = []; @@ -19,14 +23,37 @@ export abstract class SprinklersDevice { sectionRunnerConstructor: typeof SectionRunner = SectionRunner; programConstructor: typeof Program = Program; - protected constructor() { + private references: number = 0; + + protected constructor(rpc: SprinklersRPC, id: string) { + this.rpc = rpc; + this.id = id; this.sectionRunner = new (this.sectionRunnerConstructor)(this); } - abstract get id(): string; - abstract makeRequest(request: req.Request): Promise; + /** + * Increase the reference count for this sprinklers device + * @returns The new reference count + */ + acquire(): number { + return ++this.references; + } + + /** + * Releases one reference to this device. When the reference count reaches 0, the device + * will be released and no longer updated. + * @returns The reference count after being updated + */ + release(): number { + this.references--; + if (this.references <= 0) { + this.rpc.releaseDevice(this.id); + } + return this.references; + } + runProgram(opts: req.WithProgram) { return this.makeRequest({ ...opts, type: "runProgram" }); } diff --git a/common/sprinklersRpc/SprinklersRPC.ts b/common/sprinklersRpc/SprinklersRPC.ts index 938e55d..fc17115 100644 --- a/common/sprinklersRpc/SprinklersRPC.ts +++ b/common/sprinklersRpc/SprinklersRPC.ts @@ -1,13 +1,31 @@ import { ConnectionState } from "./ConnectionState"; import { SprinklersDevice } from "./SprinklersDevice"; -export interface SprinklersRPC { - readonly connectionState: ConnectionState; - readonly connected: boolean; +export abstract class SprinklersRPC { + abstract readonly connectionState: ConnectionState; + abstract readonly connected: boolean; - start(): void; + abstract start(): void; - getDevice(id: string): SprinklersDevice; + /** + * Acquires a reference to a device. This reference must be released by calling + * SprinklersDevice#release for every time this method was called + * @param id The id of the device + */ + acquireDevice(id: string): SprinklersDevice { + const device = this.getDevice(id); + device.acquire(); + return device; + } - removeDevice(id: string): void; + /** + * Forces a device to be released. The device will no longer be updated. + * + * This should not be used normally, instead SprinklersDevice#release should be called to manage + * each reference to a device. + * @param id The id of the device to remove + */ + abstract releaseDevice(id: string): void; + + protected abstract getDevice(id: string): SprinklersDevice; } diff --git a/common/sprinklersRpc/index.ts b/common/sprinklersRpc/index.ts index b615c38..e059723 100644 --- a/common/sprinklersRpc/index.ts +++ b/common/sprinklersRpc/index.ts @@ -1,4 +1,3 @@ -// export * from "./Duration"; export * from "./SprinklersRPC"; export * from "./Program"; export * from "./schedule"; @@ -6,3 +5,4 @@ export * from "./Section"; export * from "./SectionRunner"; export * from "./SprinklersDevice"; export * from "./ConnectionState"; +export * from "./RpcError"; diff --git a/common/sprinklersRpc/mqtt/index.ts b/common/sprinklersRpc/mqtt/index.ts index b6522b1..b28f606 100644 --- a/common/sprinklersRpc/mqtt/index.ts +++ b/common/sprinklersRpc/mqtt/index.ts @@ -1,9 +1,11 @@ import { autorun, observable } from "mobx"; import * as mqtt from "mqtt"; +import { ErrorCode } from "@common/ErrorCode"; import logger from "@common/logger"; import * as s from "@common/sprinklersRpc"; import * as requests from "@common/sprinklersRpc/deviceRequests"; +import { RpcError } from "@common/sprinklersRpc/RpcError"; import { seralizeRequest } from "@common/sprinklersRpc/schema/requests"; import { getRandomId } from "@common/utils"; @@ -18,6 +20,7 @@ interface WithRid { } export const DEVICE_PREFIX = "devices"; +const REQUEST_TIMEOUT = 5000; export interface MqttRpcClientOptions { mqttUri: string; @@ -25,7 +28,7 @@ export interface MqttRpcClientOptions { password?: string; } -export class MqttRpcClient implements s.SprinklersRPC, MqttRpcClientOptions { +export class MqttRpcClient extends s.SprinklersRPC implements MqttRpcClientOptions { get connected(): boolean { return this.connectionState.isServerConnected || false; } @@ -43,6 +46,7 @@ export class MqttRpcClient implements s.SprinklersRPC, MqttRpcClientOptions { devices: Map = new Map(); constructor(opts: MqttRpcClientOptions) { + super(); Object.assign(this, opts); this.connectionState.serverToBroker = false; } @@ -69,7 +73,16 @@ export class MqttRpcClient implements s.SprinklersRPC, MqttRpcClientOptions { }); } - getDevice(id: string): s.SprinklersDevice { + releaseDevice(id: string) { + const device = this.devices.get(id); + if (!device) { + return; + } + device.doUnsubscribe(); + this.devices.delete(id); + } + + protected getDevice(id: string): s.SprinklersDevice { if (/\//.test(id)) { throw new Error("Device id cannot contain a /"); } @@ -83,15 +96,6 @@ export class MqttRpcClient implements s.SprinklersRPC, MqttRpcClientOptions { return device; } - removeDevice(id: string) { - const device = this.devices.get(id); - if (!device) { - return; - } - device.doUnsubscribe(); - this.devices.delete(id); - } - private onMessageArrived(topic: string, payload: Buffer, packet: mqtt.Packet) { try { this.processMessage(topic, payload, packet); @@ -119,6 +123,8 @@ export class MqttRpcClient implements s.SprinklersRPC, MqttRpcClientOptions { } } +type ResponseCallback = (response: requests.Response) => void; + const subscriptions = [ "/connected", "/sections", @@ -148,7 +154,6 @@ const handler = (test: RegExp) => class MqttSprinklersDevice extends s.SprinklersDevice { readonly apiClient: MqttRpcClient; - readonly id: string; handlers!: IHandlerEntry[]; private subscriptions: string[]; @@ -156,12 +161,11 @@ class MqttSprinklersDevice extends s.SprinklersDevice { private responseCallbacks: Map = new Map(); constructor(apiClient: MqttRpcClient, id: string) { - super(); + super(apiClient, id); this.sectionConstructor = MqttSection; this.sectionRunnerConstructor = MqttSectionRunner; this.programConstructor = MqttProgram; this.apiClient = apiClient; - this.id = id; this.sectionRunner = new MqttSectionRunner(this); this.subscriptions = subscriptions.map((filter) => this.prefix + filter); @@ -183,27 +187,23 @@ class MqttSprinklersDevice extends s.SprinklersDevice { return DEVICE_PREFIX + "/" + this.id; } - doSubscribe(): Promise { - return new Promise((resolve, reject) => { - this.apiClient.client.subscribe(this.subscriptions, { qos: 1 }, (err) => { - if (err) { - reject(err); - } else { - resolve(); - } - }); + doSubscribe() { + this.apiClient.client.subscribe(this.subscriptions, { qos: 1 }, (err) => { + if (err) { + log.error({ err, id: this.id }, "error subscribing to device"); + } else { + log.debug({ id: this.id }, "subscribed to device"); + } }); } - doUnsubscribe(): Promise { - return new Promise((resolve, reject) => { - this.apiClient.client.unsubscribe(this.subscriptions, (err) => { - if (err) { - reject(err); - } else { - resolve(); - } - }); + doUnsubscribe() { + this.apiClient.client.unsubscribe(this.subscriptions, (err) => { + if (err) { + log.error({ err, id: this.id }, "error unsubscribing to device"); + } else { + log.debug({ id: this.id }, "unsubscribed to device"); + } }); } @@ -226,14 +226,25 @@ class MqttSprinklersDevice extends s.SprinklersDevice { const json = seralizeRequest(request); const requestId = json.rid = this.getRequestId(); const payloadStr = JSON.stringify(json); - this.responseCallbacks.set(requestId, (data) => { + + let timeoutHandle: any; + const callback: ResponseCallback = (data) => { if (data.result === "error") { - reject(data); + reject(new RpcError(data.message, data.code, data)); } else { resolve(data); } this.responseCallbacks.delete(requestId); - }); + clearTimeout(timeoutHandle); + }; + + timeoutHandle = setTimeout(() => { + reject(new RpcError("the request has timed out", ErrorCode.Timeout)); + this.responseCallbacks.delete(requestId); + clearTimeout(timeoutHandle); + }, REQUEST_TIMEOUT); + + this.responseCallbacks.set(requestId, callback); this.apiClient.client.publish(topic, payloadStr, { qos: 1 }); }); } @@ -298,5 +309,3 @@ class MqttSprinklersDevice extends s.SprinklersDevice { /* tslint:enable:no-unused-variable */ } - -type ResponseCallback = (response: requests.Response) => void; diff --git a/common/sprinklersRpc/websocketData.ts b/common/sprinklersRpc/websocketData.ts index 4235b81..1cc7197 100644 --- a/common/sprinklersRpc/websocketData.ts +++ b/common/sprinklersRpc/websocketData.ts @@ -1,6 +1,5 @@ import * as rpc from "../jsonRpc/index"; -import { ErrorCode } from "@common/ErrorCode"; import { IUser } from "@common/httpApi"; import { Response as ResponseData } from "@common/sprinklersRpc/deviceRequests"; @@ -20,6 +19,7 @@ export interface IDeviceCallRequest { export interface IClientRequestTypes { "authenticate": IAuthenticateRequest; "deviceSubscribe": IDeviceSubscribeRequest; + "deviceUnsubscribe": IDeviceSubscribeRequest; "deviceCall": IDeviceCallRequest; } @@ -40,6 +40,7 @@ export interface IDeviceCallResponse { export interface IServerResponseTypes { "authenticate": IAuthenticateResponse; "deviceSubscribe": IDeviceSubscribeResponse; + "deviceUnsubscribe": IDeviceSubscribeResponse; "deviceCall": IDeviceCallResponse; } @@ -64,25 +65,6 @@ export type ServerNotificationMethod = keyof IServerNotificationTypes; export type IError = rpc.DefaultErrorType; export type ErrorData = rpc.ErrorData; -export class RpcError extends Error implements IError { - name = "RpcError"; - code: number; - data: any; - - constructor(message: string, code: number = ErrorCode.BadRequest, data: any = {}) { - super(message); - this.code = code; - if (data instanceof Error) { - this.data = data.toString(); - } - this.data = data; - } - - toJSON(): IError { - return { code: this.code, message: this.message, data: this.data }; - } -} - export type ServerMessage = rpc.Message<{}, IServerResponseTypes, IError, IServerNotificationTypes>; export type ServerNotification = rpc.Notification; export type ServerResponse = rpc.Response; diff --git a/common/utils.ts b/common/utils.ts index bfd1896..85cba02 100644 --- a/common/utils.ts +++ b/common/utils.ts @@ -16,3 +16,11 @@ export function checkedIndexOf(o: T | number, arr: T[], type: string = "objec export function getRandomId() { return Math.floor(Math.random() * 1000000000); } + +export function applyMixins(derivedCtor: any, baseCtors: any[]) { + baseCtors.forEach((baseCtor) => { + Object.getOwnPropertyNames(baseCtor.prototype).forEach((name) => { + derivedCtor.prototype[name] = baseCtor.prototype[name]; + }); + }); +} diff --git a/server/express/api/devices.ts b/server/express/api/devices.ts index bbf37b0..fcffd82 100644 --- a/server/express/api/devices.ts +++ b/server/express/api/devices.ts @@ -36,9 +36,10 @@ export function devices(state: ServerState) { if (!userDevice) { throw new ApiError("User does not have access to the specified device", ErrorCode.NoPermission); } - const device = state.mqttClient.getDevice(req.params.deviceId); + const device = state.mqttClient.acquireDevice(req.params.deviceId); const j = serialize(schema.sprinklersDevice, device); res.send(j); + device.release(); }); router.post("/register", verifyAuthorization({ diff --git a/server/express/authentication.ts b/server/express/authentication.ts index c4309ce..f8acd0f 100644 --- a/server/express/authentication.ts +++ b/server/express/authentication.ts @@ -10,14 +10,14 @@ import { TokenGrantRequest, TokenGrantResponse, } from "@common/httpApi"; -import { AccessToken, DeviceRegistrationToken, DeviceToken, RefreshToken, TokenClaims, SuperuserToken } from "@common/TokenClaims"; +import * as tok from "@common/TokenClaims"; import { User } from "../entities"; import { ServerState } from "../state"; declare global { namespace Express { interface Request { - token?: AccessToken; + token?: tok.AccessToken; } } } @@ -39,7 +39,7 @@ function getExpTime(lifetime: number) { return Math.floor(Date.now() / 1000) + lifetime; } -function signToken(claims: TokenClaims): Promise { +function signToken(claims: tok.TokenClaims): Promise { return new Promise((resolve, reject) => { jwt.sign(claims, JWT_SECRET, (err: Error, encoded: string) => { if (err) { @@ -51,7 +51,7 @@ function signToken(claims: TokenClaims): Promise { }); } -export function verifyToken( +export function verifyToken( token: string, type?: TClaims["type"], ): Promise { return new Promise((resolve, reject) => { @@ -67,7 +67,7 @@ export function verifyToken( reject(err); } } else { - const claims: TokenClaims = decoded as any; + const claims: tok.TokenClaims = decoded as any; if (type != null && claims.type !== type) { reject(new ApiError(`Expected a "${type}" token, received a "${claims.type}" token`, ErrorCode.BadToken)); @@ -79,7 +79,7 @@ export function verifyToken( } function generateAccessToken(user: User, secret: string): Promise { - const access_token_claims: AccessToken = { + const access_token_claims: tok.AccessToken = { iss: ISSUER, aud: user.id, name: user.name, @@ -91,7 +91,7 @@ function generateAccessToken(user: User, secret: string): Promise { } function generateRefreshToken(user: User, secret: string): Promise { - const refresh_token_claims: RefreshToken = { + const refresh_token_claims: tok.RefreshToken = { iss: ISSUER, aud: user.id, name: user.name, @@ -103,7 +103,7 @@ function generateRefreshToken(user: User, secret: string): Promise { } function generateDeviceRegistrationToken(secret: string): Promise { - const device_reg_token_claims: DeviceRegistrationToken = { + const device_reg_token_claims: tok.DeviceRegistrationToken = { iss: ISSUER, type: "device_reg", }; @@ -111,7 +111,7 @@ function generateDeviceRegistrationToken(secret: string): Promise { } export function generateDeviceToken(id: number, deviceId: string): Promise { - const device_token_claims: DeviceToken = { + const device_token_claims: tok.DeviceToken = { iss: ISSUER, type: "device", aud: deviceId, @@ -121,7 +121,7 @@ export function generateDeviceToken(id: number, deviceId: string): Promise { - const superuser_claims: SuperuserToken = { + const superuser_claims: tok.SuperuserToken = { iss: ISSUER, type: "superuser", }; @@ -200,7 +200,7 @@ export function authentication(state: ServerState) { } export interface VerifyAuthorizationOpts { - type: TokenClaims["type"]; + type: tok.TokenClaims["type"]; } export function verifyAuthorization(options?: Partial): Express.RequestHandler { diff --git a/server/index.ts b/server/index.ts index cc22a84..e40891b 100644 --- a/server/index.ts +++ b/server/index.ts @@ -10,7 +10,7 @@ import * as WebSocket from "ws"; import { ServerState } from "./state"; import { createApp } from "./express"; -import { WebSocketApi } from "./sprinklersRpc/websocketServer"; +import { WebSocketApi } from "./sprinklersRpc/WebSocketApi"; const state = new ServerState(); const app = createApp(state); diff --git a/server/sprinklersRpc/WebSocketApi.ts b/server/sprinklersRpc/WebSocketApi.ts new file mode 100644 index 0000000..96aa093 --- /dev/null +++ b/server/sprinklersRpc/WebSocketApi.ts @@ -0,0 +1,26 @@ +import * as WebSocket from "ws"; + +import { ServerState } from "@server/state"; +import { WebSocketConnection } from "./WebSocketConnection"; + +export class WebSocketApi { + state: ServerState; + clients: Set = new Set(); + + constructor(state: ServerState) { + this.state = state; + } + + listen(webSocketServer: WebSocket.Server) { + webSocketServer.on("connection", this.handleConnection); + } + + handleConnection = (socket: WebSocket) => { + const client = new WebSocketConnection(this, socket); + this.clients.add(client); + } + + removeClient(client: WebSocketConnection) { + return this.clients.delete(client); + } +} diff --git a/server/sprinklersRpc/WebSocketConnection.ts b/server/sprinklersRpc/WebSocketConnection.ts new file mode 100644 index 0000000..5897332 --- /dev/null +++ b/server/sprinklersRpc/WebSocketConnection.ts @@ -0,0 +1,262 @@ +import { autorun } from "mobx"; +import { serialize } from "serializr"; +import * as WebSocket from "ws"; + +import { ErrorCode } from "@common/ErrorCode"; +import * as rpc from "@common/jsonRpc"; +import log from "@common/logger"; +import { RpcError } from "@common/sprinklersRpc"; +import * as deviceRequests from "@common/sprinklersRpc/deviceRequests"; +import * as schema from "@common/sprinklersRpc/schema"; +import * as ws from "@common/sprinklersRpc/websocketData"; +import { AccessToken } from "@common/TokenClaims"; +import { User } from "@server/entities"; +import { verifyToken } from "@server/express/authentication"; + +import { WebSocketApi } from "./WebSocketApi"; + +type Disposer = () => void; + +export class WebSocketConnection { + api: WebSocketApi; + socket: WebSocket; + + disposers: Array<() => void> = []; + // map of device id to disposer function + deviceSubscriptions: Map = new Map(); + + /// This shall be the user id if the client has been authenticated, null otherwise + userId: number | null = null; + user: User | null = null; + + private requestHandlers: ws.ClientRequestHandlers = new WebSocketRequestHandlers(); + + get state() { + return this.api.state; + } + + constructor(api: WebSocketApi, socket: WebSocket) { + this.api = api; + this.socket = socket; + + this.socket.on("message", this.handleSocketMessage); + this.socket.on("close", this.onClose); + } + + stop = () => { + this.socket.close(); + } + + onClose = (code: number, reason: string) => { + log.debug({ code, reason }, "WebSocketConnection closing"); + this.disposers.forEach((disposer) => disposer()); + this.deviceSubscriptions.forEach((disposer) => disposer()); + this.api.removeClient(this); + } + + subscribeBrokerConnection() { + this.disposers.push(autorun(() => { + const updateData: ws.IBrokerConnectionUpdate = { + brokerConnected: this.state.mqttClient.connected, + }; + this.sendNotification("brokerConnectionUpdate", updateData); + })); + } + + checkAuthorization() { + if (!this.userId || !this.user) { + throw new RpcError("this WebSocket session has not been authenticated", + ErrorCode.Unauthorized); + } + } + + checkDevice(devId: string) { + const userDevice = this.user!.devices!.find((dev) => dev.deviceId === devId); + if (userDevice == null) { + throw new RpcError("you do not have permission to subscribe to device", + ErrorCode.NoPermission, { id: devId }); + } + const deviceId = userDevice.deviceId; + if (!deviceId) { + throw new RpcError("device has no associated device prefix", ErrorCode.Internal); + } + return userDevice; + } + + sendMessage(data: ws.ServerMessage) { + this.socket.send(JSON.stringify(data)); + } + + sendNotification( + method: Method, + data: ws.IServerNotificationTypes[Method]) { + this.sendMessage({ type: "notification", method, data }); + } + + sendResponse( + method: Method, + id: number, + data: ws.ServerResponseData) { + this.sendMessage({ type: "response", method, id, ...data }); + } + + handleSocketMessage = (socketData: WebSocket.Data) => { + this.doHandleSocketMessage(socketData) + .catch((err) => { + this.onError({ err }, "unhandled error on handling socket message"); + }); + } + + async doDeviceCallRequest(requestData: ws.IDeviceCallRequest): Promise { + const userDevice = this.checkDevice(requestData.deviceId); + const deviceId = userDevice.deviceId!; + const device = this.state.mqttClient.acquireDevice(deviceId); + try { + const request = schema.requests.deserializeRequest(requestData.data); + return await device.makeRequest(request); + } finally { + device.release(); + } + } + + private async doHandleSocketMessage(socketData: WebSocket.Data) { + if (typeof socketData !== "string") { + return this.onError({ type: typeof socketData }, + "received invalid socket data type from client", ErrorCode.Parse); + } + let data: ws.ClientMessage; + try { + data = JSON.parse(socketData); + } catch (err) { + return this.onError({ socketData, err }, "received invalid websocket message from client", + ErrorCode.Parse); + } + switch (data.type) { + case "request": + await this.handleRequest(data); + break; + default: + return this.onError({ data }, "received invalid message type from client", + ErrorCode.BadRequest); + } + } + + private async handleRequest(request: ws.ClientRequest) { + let response: ws.ServerResponseData; + try { + if (!this.requestHandlers[request.method]) { + // noinspection ExceptionCaughtLocallyJS + throw new RpcError("received invalid client request method"); + } + response = await rpc.handleRequest(this.requestHandlers, request, this); + } catch (err) { + if (err instanceof RpcError) { + log.debug({ err }, "rpc error"); + response = { result: "error", error: err.toJSON() }; + } else { + log.error({ method: request.method, err }, "unhandled error during processing of client request"); + response = { + result: "error", error: { + code: ErrorCode.Internal, message: "unhandled error during processing of client request", + data: err.toString(), + }, + }; + } + } + this.sendResponse(request.method, request.id, response); + } + + private onError(data: any, message: string, code: number = ErrorCode.Internal) { + log.error(data, message); + const errorData: ws.IError = { code, message, data }; + this.sendNotification("error", errorData); + } +} + +class WebSocketRequestHandlers implements ws.ClientRequestHandlers { + async authenticate(this: WebSocketConnection, data: ws.IAuthenticateRequest): + Promise> { + if (!data.accessToken) { + throw new RpcError("no token specified", ErrorCode.BadRequest); + } + let claims: AccessToken; + try { + claims = await verifyToken(data.accessToken, "access"); + } catch (e) { + throw new RpcError("invalid token", ErrorCode.BadToken, e); + } + this.userId = claims.aud; + this.user = await this.state.database.users. + findById(this.userId, { devices: true }) || null; + if (!this.user) { + throw new RpcError("user no longer exists", ErrorCode.BadToken); + } + log.debug({ userId: claims.aud, name: claims.name }, "authenticated websocket client"); + this.subscribeBrokerConnection(); + return { + result: "success", + data: { authenticated: true, message: "authenticated", user: this.user.toJSON() }, + }; + } + + async deviceSubscribe(this: WebSocketConnection, data: ws.IDeviceSubscribeRequest): + Promise> { + this.checkAuthorization(); + const userDevice = this.checkDevice(data.deviceId); + const deviceId = userDevice.deviceId!; + if (!this.deviceSubscriptions.has(deviceId)) { + const device = this.state.mqttClient.acquireDevice(deviceId); + log.debug({ deviceId, userId: this.userId }, "websocket client subscribed to device"); + + const autorunDisposer = autorun(() => { + const json = serialize(schema.sprinklersDevice, device); + log.trace({ device: json }); + const updateData: ws.IDeviceUpdate = { deviceId, data: json }; + this.sendNotification("deviceUpdate", updateData); + }, { delay: 100 }); + + this.deviceSubscriptions.set(deviceId, () => { + autorunDisposer(); + device.release(); + this.deviceSubscriptions.delete(deviceId); + }); + } + + const response: ws.IDeviceSubscribeResponse = { + deviceId, + }; + return { result: "success", data: response }; + } + + async deviceUnsubscribe(this: WebSocketConnection, data: ws.IDeviceSubscribeRequest): + Promise> { + this.checkAuthorization(); + const userDevice = this.checkDevice(data.deviceId); + const deviceId = userDevice.deviceId!; + const disposer = this.deviceSubscriptions.get(deviceId); + + if (disposer) { + disposer(); + } + + const response: ws.IDeviceSubscribeResponse = { + deviceId, + }; + return { result: "success", data: response }; + } + + async deviceCall(this: WebSocketConnection, data: ws.IDeviceCallRequest): + Promise> { + this.checkAuthorization(); + try { + const response = await this.doDeviceCallRequest(data); + const resData: ws.IDeviceCallResponse = { + data: response, + }; + return { result: "success", data: resData }; + } catch (err) { + const e: deviceRequests.ErrorResponseData = err; + throw new RpcError(e.message, e.code, e); + } + } +} diff --git a/server/sprinklersRpc/websocketServer.ts b/server/sprinklersRpc/websocketServer.ts deleted file mode 100644 index 3cf2981..0000000 --- a/server/sprinklersRpc/websocketServer.ts +++ /dev/null @@ -1,247 +0,0 @@ -import { autorun } from "mobx"; -import { serialize } from "serializr"; -import * as WebSocket from "ws"; - -import { ErrorCode } from "@common/ErrorCode"; -import * as rpc from "@common/jsonRpc"; -import log from "@common/logger"; -import * as deviceRequests from "@common/sprinklersRpc/deviceRequests"; -import * as schema from "@common/sprinklersRpc/schema"; -import * as ws from "@common/sprinklersRpc/websocketData"; -import { AccessToken } from "@common/TokenClaims"; -import { User } from "@server/entities"; -import { verifyToken } from "@server/express/authentication"; -import { ServerState } from "@server/state"; - -// tslint:disable:member-ordering - -export class WebSocketClient { - api: WebSocketApi; - socket: WebSocket; - - disposers: Array<() => void> = []; - deviceSubscriptions: string[] = []; - - /// This shall be the user id if the client has been authenticated, null otherwise - userId: number | null = null; - user: User | null = null; - - get state() { - return this.api.state; - } - - constructor(api: WebSocketApi, socket: WebSocket) { - this.api = api; - this.socket = socket; - } - - start() { - this.socket.on("message", this.handleSocketMessage); - this.socket.on("close", this.stop); - } - - stop = () => { - this.disposers.forEach((disposer) => disposer()); - this.api.removeClient(this); - } - - private subscribeBrokerConnection() { - this.disposers.push(autorun(() => { - const updateData: ws.IBrokerConnectionUpdate = { - brokerConnected: this.state.mqttClient.connected, - }; - this.sendNotification("brokerConnectionUpdate", updateData); - })); - } - - private checkAuthorization() { - if (!this.userId || !this.user) { - throw new ws.RpcError("this WebSocket session has not been authenticated", - ErrorCode.Unauthorized); - } - } - - private checkDevice(devId: string) { - const userDevice = this.user!.devices!.find((dev) => dev.deviceId === devId); - if (userDevice == null) { - throw new ws.RpcError("you do not have permission to subscribe to this device", - ErrorCode.NoPermission); - } - const deviceId = userDevice.deviceId; - if (!deviceId) { - throw new ws.RpcError("device has no associated device prefix", ErrorCode.BadRequest); - } - return userDevice; - } - - private requestHandlers: ws.ClientRequestHandlers = { - authenticate: async (data: ws.IAuthenticateRequest) => { - if (!data.accessToken) { - throw new ws.RpcError("no token specified", ErrorCode.BadRequest); - } - let claims: AccessToken; - try { - claims = await verifyToken(data.accessToken, "access"); - } catch (e) { - throw new ws.RpcError("invalid token", ErrorCode.BadToken, e); - } - this.userId = claims.aud; - this.user = await this.state.database.users. - findById(this.userId, { devices: true }) || null; - if (!this.user) { - throw new ws.RpcError("user no longer exists", ErrorCode.BadToken); - } - log.info({ userId: claims.aud, name: claims.name }, "authenticated websocket client"); - this.subscribeBrokerConnection(); - return { - result: "success", - data: { authenticated: true, message: "authenticated", user: this.user.toJSON() }, - }; - }, - deviceSubscribe: async (data: ws.IDeviceSubscribeRequest) => { - this.checkAuthorization(); - const userDevice = this.checkDevice(data.deviceId); - const deviceId = userDevice.deviceId!; - if (this.deviceSubscriptions.indexOf(deviceId) === -1) { - this.deviceSubscriptions.push(deviceId); - const device = this.state.mqttClient.getDevice(deviceId); - log.debug({ deviceId, userId: this.userId }, "websocket client subscribed to device"); - this.disposers.push(autorun(() => { - const json = serialize(schema.sprinklersDevice, device); - log.trace({ device: json }); - const updateData: ws.IDeviceUpdate = { deviceId, data: json }; - this.sendNotification("deviceUpdate", updateData); - }, { delay: 100 })); - } - - const response: ws.IDeviceSubscribeResponse = { - deviceId, - }; - return { result: "success", data: response }; - }, - deviceCall: async (data: ws.IDeviceCallRequest) => { - this.checkAuthorization(); - try { - const response = await this.doDeviceCallRequest(data); - const resData: ws.IDeviceCallResponse = { - data: response, - }; - return { result: "success", data: resData }; - } catch (err) { - const e: deviceRequests.ErrorResponseData = err; - throw new ws.RpcError(e.message, e.code, e); - } - }, - }; - - private sendMessage(data: ws.ServerMessage) { - this.socket.send(JSON.stringify(data)); - } - - private sendNotification( - method: Method, - data: ws.IServerNotificationTypes[Method]) { - this.sendMessage({ type: "notification", method, data }); - } - - private sendResponse( - method: Method, - id: number, - data: ws.ServerResponseData) { - this.sendMessage({ type: "response", method, id, ...data }); - } - - private handleSocketMessage = (socketData: WebSocket.Data) => { - this.doHandleSocketMessage(socketData) - .catch((err) => { - this.onError({ err }, "unhandled error on handling socket message"); - }); - } - - private async doHandleSocketMessage(socketData: WebSocket.Data) { - if (typeof socketData !== "string") { - return this.onError({ type: typeof socketData }, - "received invalid socket data type from client", ErrorCode.Parse); - } - let data: ws.ClientMessage; - try { - data = JSON.parse(socketData); - } catch (err) { - return this.onError({ socketData, err }, "received invalid websocket message from client", - ErrorCode.Parse); - } - switch (data.type) { - case "request": - await this.handleRequest(data); - break; - default: - return this.onError({ data }, "received invalid message type from client", - ErrorCode.BadRequest); - } - } - - private async handleRequest(request: ws.ClientRequest) { - let response: ws.ServerResponseData; - try { - if (!this.requestHandlers[request.method]) { - // noinspection ExceptionCaughtLocallyJS - throw new ws.RpcError("received invalid client request method"); - } - response = await rpc.handleRequest(this.requestHandlers, request); - } catch (err) { - if (err instanceof ws.RpcError) { - log.debug({ err }, "rpc error"); - response = { result: "error", error: err.toJSON() }; - } else { - log.error({ method: request.method, err }, "unhandled error during processing of client request"); - response = { - result: "error", error: { - code: ErrorCode.Internal, message: "unhandled error during processing of client request", - data: err.toString(), - }, - }; - } - } - this.sendResponse(request.method, request.id, response); - } - - private onError(data: any, message: string, code: number = ErrorCode.Internal) { - log.error(data, message); - const errorData: ws.IError = { code, message, data }; - this.sendNotification("error", errorData); - } - - private async doDeviceCallRequest(requestData: ws.IDeviceCallRequest): Promise { - const userDevice = this.checkDevice(requestData.deviceId); - const deviceId = userDevice.deviceId!; - const device = this.state.mqttClient.getDevice(deviceId); - const request = schema.requests.deserializeRequest(requestData.data); - return device.makeRequest(request); - } -} - -export class WebSocketApi { - state: ServerState; - clients: WebSocketClient[] = []; - - constructor(state: ServerState) { - this.state = state; - } - - listen(webSocketServer: WebSocket.Server) { - webSocketServer.on("connection", this.handleConnection); - } - - handleConnection = (socket: WebSocket) => { - const client = new WebSocketClient(this, socket); - client.start(); - this.clients.push(client); - } - - removeClient(client: WebSocketClient) { - const idx = this.clients.indexOf(client); - if (idx !== -1) { - this.clients.splice(idx, 1); - } - } -}