From f8a1dd0a8c2d4ad336a9f5d5fe0aedce2283002d Mon Sep 17 00:00:00 2001 From: Alex Mikhalev Date: Fri, 17 Aug 2018 13:48:34 -0600 Subject: [PATCH] Better handling of expired access tokens --- client/sprinklersRpc/WebSocketRpcClient.ts | 72 +++++++++++++--------- client/state/AppState.ts | 41 +++++++++--- client/state/HttpApi.ts | 63 +++++++++++++------ client/state/Token.ts | 2 +- client/state/UserStore.ts | 2 +- common/TypedEventEmitter.ts | 48 +++++++++++++++ server/Database.ts | 1 + 7 files changed, 171 insertions(+), 58 deletions(-) create mode 100644 common/TypedEventEmitter.ts diff --git a/client/sprinklersRpc/WebSocketRpcClient.ts b/client/sprinklersRpc/WebSocketRpcClient.ts index 10a63bd..0f19ba5 100644 --- a/client/sprinklersRpc/WebSocketRpcClient.ts +++ b/client/sprinklersRpc/WebSocketRpcClient.ts @@ -2,8 +2,8 @@ import { action, autorun, observable, runInAction, when } from "mobx"; import { update } from "serializr"; import { TokenStore } from "@client/state/TokenStore"; -import { UserStore } from "@client/state/UserStore"; import { ErrorCode } from "@common/ErrorCode"; +import { IUser } from "@common/httpApi"; import * as rpc from "@common/jsonRpc"; import logger from "@common/logger"; import * as deviceRequests from "@common/sprinklersRpc/deviceRequests"; @@ -11,6 +11,7 @@ import * as s from "@common/sprinklersRpc/index"; import * as schema from "@common/sprinklersRpc/schema/index"; import { seralizeRequest } from "@common/sprinklersRpc/schema/requests"; import * as ws from "@common/sprinklersRpc/websocketData"; +import { DefaultEvents, TypedEventEmitter } from "@common/TypedEventEmitter"; const log = logger.child({ source: "websocket" }); @@ -83,7 +84,13 @@ export class WSSprinklersDevice extends s.SprinklersDevice { } } -export class WebSocketRpcClient implements s.SprinklersRPC { +export interface WebSocketRpcClientEvents extends DefaultEvents { + newUserData(userData: IUser): void; + rpcError(error: ws.RpcError): void; + tokenError(error: ws.RpcError): void; +} + +export class WebSocketRpcClient extends TypedEventEmitter implements s.SprinklersRPC { readonly webSocketUrl: string; devices: Map = new Map(); @@ -94,7 +101,6 @@ export class WebSocketRpcClient implements s.SprinklersRPC { authenticated: boolean = false; tokenStore: TokenStore; - userStore: UserStore; private nextRequestId = Math.round(Math.random() * 1000000); private responseCallbacks: ws.ServerResponseHandlers = {}; @@ -104,12 +110,18 @@ export class WebSocketRpcClient implements s.SprinklersRPC { return this.connectionState.isServerConnected || false; } - constructor(tokenStore: TokenStore, userStore: UserStore, webSocketUrl: string = DEFAULT_URL) { + constructor(tokenStore: TokenStore, webSocketUrl: string = DEFAULT_URL) { + super(); this.webSocketUrl = webSocketUrl; this.tokenStore = tokenStore; - this.userStore = userStore; this.connectionState.clientToServer = false; this.connectionState.serverToBroker = false; + + this.on("rpcError", (err: ws.RpcError) => { + if (err.code === ErrorCode.BadToken) { + this.emit("tokenError", err); + } + }); } start() { @@ -149,13 +161,17 @@ export class WebSocketRpcClient implements s.SprinklersRPC { && this.tokenStore.accessToken.isValid, async () => { try { const res = await this.authenticate(this.tokenStore.accessToken.token!); - this.authenticated = res.authenticated; + runInAction("authenticateSuccess", () => { + this.authenticated = res.authenticated; + }); logger.info({ user: res.user }, "authenticated websocket connection"); - this.userStore.receiveUserData(res.user); + this.emit("newUserData", res.user); } catch (err) { logger.error({ err }, "error authenticating websocket connection"); // TODO message? - this.authenticated = false; + runInAction("authenticateSuccess", () => { + this.authenticated = false; + }); } }); } @@ -167,17 +183,13 @@ export class WebSocketRpcClient implements s.SprinklersRPC { code: ErrorCode.ServerDisconnected, message: "the server is not connected", }; - throw error; + throw new ws.RpcError("the server is not connected", ErrorCode.ServerDisconnected); } const requestData = seralizeRequest(request); const data: ws.IDeviceCallRequest = { deviceId, data: requestData }; const resData = await this.makeRequest("deviceCall", data); if (resData.data.result === "error") { - throw { - code: resData.data.code, - message: resData.data.message, - data: resData.data, - }; + throw new ws.RpcError(resData.data.message, resData.data.code, resData.data); } else { return resData.data; } @@ -194,21 +206,22 @@ export class WebSocketRpcClient implements s.SprinklersRPC { if (response.result === "success") { resolve(response.data); } else { - reject(response.error); + const { error } = response; + reject(new ws.RpcError(error.message, error.code, error.data)); } }; timeoutHandle = window.setTimeout(() => { delete this.responseCallbacks[id]; - const res: ws.ErrorData = { - result: "error", error: { - code: ErrorCode.Timeout, - message: "the request timed out", - }, - }; - reject(res); + reject(new ws.RpcError("the request timed out", ErrorCode.Timeout)); }, TIMEOUT_MS); this.sendRequest(id, method, params); - }); + }) + .catch((err) => { + if (err instanceof ws.RpcError) { + this.emit("rpcError", err); + } + throw err; + }); } private sendMessage(data: ws.ClientMessage) { @@ -230,7 +243,7 @@ export class WebSocketRpcClient implements s.SprinklersRPC { private _connect() { if (this.socket != null && - (this.socket.readyState === WebSocket.CLOSED)) { + (this.socket.readyState === WebSocket.OPEN)) { this.tryAuthenticate(); return; } @@ -242,6 +255,7 @@ export class WebSocketRpcClient implements s.SprinklersRPC { this.socket.onmessage = this.onMessage.bind(this); } + @action private onOpen() { log.info("established websocket connection"); this.connectionState.clientToServer = true; @@ -250,6 +264,7 @@ export class WebSocketRpcClient implements s.SprinklersRPC { } /* tslint:disable-next-line:member-ordering */ + @action private onDisconnect = action(() => { this.connectionState.serverToBroker = null; this.connectionState.clientToServer = false; @@ -263,12 +278,11 @@ export class WebSocketRpcClient implements s.SprinklersRPC { this.reconnectTimer = window.setTimeout(this._reconnect, RECONNECT_TIMEOUT_MS); } + @action private onError(event: Event) { log.error({ event }, "websocket error"); - action(() => { - this.connectionState.serverToBroker = null; - this.connectionState.clientToServer = false; - }); + this.connectionState.serverToBroker = null; + this.connectionState.clientToServer = false; this.onDisconnect(); } @@ -335,4 +349,4 @@ class WSClientNotificationHandlers implements ws.ServerNotificationHandlers { error(data: ws.IError) { log.warn({ err: data }, "server error"); } -}; +} diff --git a/client/state/AppState.ts b/client/state/AppState.ts index e8d3d33..f0a67f5 100644 --- a/client/state/AppState.ts +++ b/client/state/AppState.ts @@ -1,5 +1,5 @@ import { createBrowserHistory, History } from "history"; -import { computed, configure } from "mobx"; +import { computed, configure, when } from "mobx"; import { RouterStore, syncHistoryWithStore } from "mobx-react-router"; import { WebSocketRpcClient } from "@client/sprinklersRpc/WebSocketRpcClient"; @@ -8,16 +8,37 @@ import { UiStore } from "@client/state/UiStore"; import { UserStore } from "@client/state/UserStore"; import ApiError from "@common/ApiError"; import { ErrorCode } from "@common/ErrorCode"; +import { IUser } from "@common/httpApi"; import log from "@common/logger"; +import { TypedEventEmitter, DefaultEvents } from "@common/TypedEventEmitter"; -export default class AppState { +interface AppEvents extends DefaultEvents { + checkToken(): void; + hasToken(): void; +} + +export default class AppState extends TypedEventEmitter { history: History = createBrowserHistory(); routerStore = new RouterStore(); uiStore = new UiStore(); userStore = new UserStore(); httpApi = new HttpApi(); tokenStore = this.httpApi.tokenStore; - sprinklersRpc = new WebSocketRpcClient(this.tokenStore, this.userStore); + sprinklersRpc = new WebSocketRpcClient(this.tokenStore); + + constructor() { + super(); + this.sprinklersRpc.on("newUserData", this.userStore.receiveUserData); + this.sprinklersRpc.on("tokenError", this.checkToken); + this.httpApi.on("tokenError", this.checkToken); + + this.on("checkToken", this.doCheckToken); + + this.on("hasToken", () => { + when(() => !this.tokenStore.accessToken.isValid, this.checkToken); + this.sprinklersRpc.start(); + }); + } @computed get isLoggedIn() { return this.tokenStore.accessToken.isValid; @@ -29,15 +50,20 @@ export default class AppState { }); syncHistoryWithStore(this.history, this.routerStore); - this.tokenStore.loadLocalStorage(); + await this.tokenStore.loadLocalStorage(); await this.checkToken(); - await this.sprinklersRpc.start(); } - async checkToken() { - const { tokenStore: { accessToken, refreshToken } } = this.httpApi; + checkToken = () => { + this.emit("checkToken"); + } + + private doCheckToken = async () => { + const { accessToken, refreshToken } = this.tokenStore; + accessToken.updateCurrentTime(); if (accessToken.isValid) { // if the access token is valid, we are good + this.emit("hasToken"); return; } if (!refreshToken.isValid) { // if the refresh token is not valid, need to login again @@ -46,6 +72,7 @@ export default class AppState { } try { await this.httpApi.grantRefresh(); + this.emit("hasToken"); } catch (err) { if (err instanceof ApiError && err.code === ErrorCode.BadToken) { log.warn({ err }, "refresh is bad for some reason, erasing"); diff --git a/client/state/HttpApi.ts b/client/state/HttpApi.ts index b2b7b1d..d5715d4 100644 --- a/client/state/HttpApi.ts +++ b/client/state/HttpApi.ts @@ -3,11 +3,17 @@ import ApiError from "@common/ApiError"; import { ErrorCode } from "@common/ErrorCode"; import { TokenGrantPasswordRequest, TokenGrantRefreshRequest, TokenGrantResponse } from "@common/httpApi"; import log from "@common/logger"; +import { DefaultEvents, TypedEventEmitter } from "@common/TypedEventEmitter"; import { runInAction } from "mobx"; export { ApiError }; -export default class HttpApi { +interface HttpApiEvents extends DefaultEvents { + error(err: ApiError): void; + tokenError(err: ApiError): void; +} + +export default class HttpApi extends TypedEventEmitter { baseUrl: string; tokenStore: TokenStore; @@ -20,36 +26,53 @@ export default class HttpApi { } constructor(baseUrl: string = `${location.protocol}//${location.hostname}:${location.port}/api`) { + super(); while (baseUrl.charAt(baseUrl.length - 1) === "/") { baseUrl = baseUrl.substring(0, baseUrl.length - 1); } this.baseUrl = baseUrl; this.tokenStore = new TokenStore(); + + this.on("error", (err: ApiError) => { + if (err.code === ErrorCode.BadToken) { + this.emit("tokenError", err); + } + }); } async makeRequest(url: string, options?: RequestInit, body?: any): Promise { - options = options || {}; - options = { - headers: { - "Content-Type": "application/json", - ...this.authorizationHeader, - ...options.headers || {}, - }, - body: JSON.stringify(body), - ...options, - }; - const response = await fetch(this.baseUrl + url, options); - let responseBody: any; try { - responseBody = await response.json() || {}; - } catch (e) { - throw new ApiError("Invalid JSON response", ErrorCode.Internal, e); - } - if (!response.ok) { - throw new ApiError(responseBody.message || response.statusText, responseBody.code, responseBody.data); + options = options || {}; + options = { + headers: { + "Content-Type": "application/json", + ...this.authorizationHeader, + ...options.headers || {}, + }, + body: JSON.stringify(body), + ...options, + }; + let response: Response; + try { + response = await fetch(this.baseUrl + url, options); + } catch (err) { + throw new ApiError("Http request error", ErrorCode.Internal, err); + } + let responseBody: any; + try { + responseBody = await response.json() || {}; + } catch (e) { + throw new ApiError("Invalid JSON response", ErrorCode.Internal, e); + } + if (!response.ok) { + throw new ApiError(responseBody.message || response.statusText, responseBody.code, responseBody.data); + } + return responseBody; + } catch (err) { + this.emit("error", err); + throw err; } - return responseBody; } async grantPassword(username: string, password: string) { diff --git a/client/state/Token.ts b/client/state/Token.ts index f980ee8..3c04932 100644 --- a/client/state/Token.ts +++ b/client/state/Token.ts @@ -27,7 +27,7 @@ export class Token { return this.token; } - private updateCurrentTime = (reportChanged: boolean = true) => { + updateCurrentTime = (reportChanged: boolean = true) => { if (reportChanged) { this.isExpiredAtom.reportChanged(); } diff --git a/client/state/UserStore.ts b/client/state/UserStore.ts index ba83e70..4f835cd 100644 --- a/client/state/UserStore.ts +++ b/client/state/UserStore.ts @@ -4,7 +4,7 @@ import { action, observable } from "mobx"; export class UserStore { @observable userData: IUser | null = null; - @action + @action.bound receiveUserData(userData: IUser) { this.userData = userData; } diff --git a/common/TypedEventEmitter.ts b/common/TypedEventEmitter.ts new file mode 100644 index 0000000..a67f957 --- /dev/null +++ b/common/TypedEventEmitter.ts @@ -0,0 +1,48 @@ +import { EventEmitter } from "events"; + +type TEventName = string | symbol; + +type AnyListener = (...args: any[]) => void; + +type Arguments = TListener extends (...args: infer TArgs) => any ? TArgs : any[]; +type Listener = TEvents[TEvent] extends (...args: infer TArgs) => any ? + (...args: TArgs) => void : AnyListener; + +export interface DefaultEvents { + newListener: (event: TEventName, listener: AnyListener) => void; + removeListener: (event: TEventName, listener: AnyListener) => void; +} + +export type AnyEvents = DefaultEvents & { + [event in TEventName]: any[]; +}; + +type IEventSubscriber = + (event: TEvent, listener: Listener) => This; + +// tslint:disable:ban-types + +interface ITypedEventEmitter { + on: IEventSubscriber; + off: IEventSubscriber; + once: IEventSubscriber; + addListener: IEventSubscriber; + removeListener: IEventSubscriber; + prependListener: IEventSubscriber; + prependOnceListener: IEventSubscriber; + + emit(event: TEvent, ...args: Arguments): boolean; + listeners(event: TEvent): Function[]; + rawListeners(event: TEvent): Function[]; + eventNames(): Array; + setMaxListeners(maxListeners: number): this; + getMaxListeners(): number; + listenerCount(event: TEvent): number; +} + +const TypedEventEmitter = EventEmitter as { + new(): TypedEventEmitter, +}; +type TypedEventEmitter = ITypedEventEmitter; + +export { TypedEventEmitter }; diff --git a/server/Database.ts b/server/Database.ts index 03d9ae3..314c3ff 100644 --- a/server/Database.ts +++ b/server/Database.ts @@ -45,6 +45,7 @@ export class Database { } async insertData() { + this.conn.subscribers const NUM = 100; const users: User[] = []; for (let i = 0; i < NUM; i++) {