From aacff33812a9fe310f7c8c5e8fa3034f24707003 Mon Sep 17 00:00:00 2001 From: Marco Fucci di Napoli Date: Thu, 8 Sep 2022 17:37:17 -1000 Subject: [PATCH 1/4] Part 1: make Attributes aware of their path --- src/Main.ts | 8 ++--- src/interaction/cluster/BasicCluster.ts | 18 +++++----- src/interaction/cluster/DescriptorCluster.ts | 34 ++++++++++--------- .../cluster/GeneralCommissioningCluster.ts | 31 +++++++++-------- src/interaction/cluster/OnOffCluster.ts | 29 +++++++++------- .../cluster/OperationalCredentialsCluster.ts | 21 ++++++------ src/interaction/model/Attribute.ts | 2 ++ src/interaction/model/Cluster.ts | 21 ++++++------ src/interaction/model/Endpoint.ts | 9 +++-- 9 files changed, 93 insertions(+), 80 deletions(-) diff --git a/src/Main.ts b/src/Main.ts index 379109e9..7683dde8 100644 --- a/src/Main.ts +++ b/src/Main.ts @@ -65,12 +65,12 @@ class Main { )) .addProtocolHandler(Protocol.INTERACTION_MODEL, new InteractionProtocol(new Device([ new Endpoint(0x00, DEVICE.ROOT, [ - new BasicCluster({ vendorName, vendorId, productName, productId }), - new GeneralCommissioningCluster(), - new OperationalCredentialsCluster({devicePrivateKey: DevicePrivateKey, deviceCertificate: DeviceCertificate, deviceIntermediateCertificate: ProductIntermediateCertificate, certificateDeclaration: CertificateDeclaration}), + BasicCluster.Builder({ vendorName, vendorId, productName, productId }), + GeneralCommissioningCluster.Builder(), + OperationalCredentialsCluster.Builder({devicePrivateKey: DevicePrivateKey, deviceCertificate: DeviceCertificate, deviceIntermediateCertificate: ProductIntermediateCertificate, certificateDeclaration: CertificateDeclaration}), ]), new Endpoint(0x01, DEVICE.ON_OFF_LIGHT, [ - new OnOffCluster(commandExecutor("on"), commandExecutor("off")), + OnOffCluster.Builder(commandExecutor("on"), commandExecutor("off")), ]), ]))) .start() diff --git a/src/interaction/cluster/BasicCluster.ts b/src/interaction/cluster/BasicCluster.ts index 9d2b8979..024caca0 100644 --- a/src/interaction/cluster/BasicCluster.ts +++ b/src/interaction/cluster/BasicCluster.ts @@ -5,7 +5,6 @@ */ import { StringT, UnsignedIntT } from "../../codec/TlvObjectCodec"; -import { Attribute } from "../model/Attribute"; import { Cluster } from "../model/Cluster"; interface BasicClusterConf { @@ -16,17 +15,18 @@ interface BasicClusterConf { } export class BasicCluster extends Cluster { - constructor({ vendorName, vendorId, productName, productId }: BasicClusterConf) { + static Builder = (conf: BasicClusterConf) => (endpointId: number) => new BasicCluster(endpointId, conf); + + constructor(endpointId: number, { vendorName, vendorId, productName, productId }: BasicClusterConf) { super( + endpointId, 0x28, "Basic", - [], - [ - new Attribute(1, "VendorName", StringT, vendorName), - new Attribute(2, "VendorID", UnsignedIntT, vendorId), - new Attribute(3, "ProductName", StringT, productName), - new Attribute(4, "ProductID", UnsignedIntT, productId), - ], ); + + this.addAttribute(1, "VendorName", StringT, vendorName); + this.addAttribute(2, "VendorID", UnsignedIntT, vendorId); + this.addAttribute(3, "ProductName", StringT, productName); + this.addAttribute(4, "ProductID", UnsignedIntT, productId); } } diff --git a/src/interaction/cluster/DescriptorCluster.ts b/src/interaction/cluster/DescriptorCluster.ts index b185ddb8..6bfa8f03 100644 --- a/src/interaction/cluster/DescriptorCluster.ts +++ b/src/interaction/cluster/DescriptorCluster.ts @@ -5,30 +5,32 @@ */ import { ArrayT, Field, ObjectT, UnsignedIntT } from "../../codec/TlvObjectCodec"; -import { Attribute } from "../model/Attribute"; import { Cluster } from "../model/Cluster"; import { Endpoint } from "../model/Endpoint"; const CLUSTER_ID = 0x1d; export class DescriptorCluster extends Cluster { - constructor(endpoint: Endpoint, allEndpoints: Endpoint[]) { + static Builder = (allEndpoints: Endpoint[]) => (endpointId: number) => new DescriptorCluster(endpointId, allEndpoints); + + constructor(endpointId: number, allEndpoints: Endpoint[]) { super( - CLUSTER_ID, + endpointId, + 0x1d, "Descriptor", - [], - [ - new Attribute(0, "DeviceList", ArrayT(ObjectT({ - type: Field(0, UnsignedIntT), - revision: Field(1, UnsignedIntT), - })), [{ - type: endpoint.device.code, - revision: 1, - }]), - new Attribute(1, "ServerList", ArrayT(UnsignedIntT), [CLUSTER_ID, ...endpoint.getClusterIds()]), - new Attribute(3, "ClientList", ArrayT(UnsignedIntT), []), - new Attribute(4, "PartsList", ArrayT(UnsignedIntT), endpoint.id === 0 ? allEndpoints.map(endpoint => endpoint.id).filter(endpointId => endpointId !== 0) : []), - ], ); + const endpoint = allEndpoints.find(endpoint => endpoint.id === endpointId); + if (endpoint === undefined) throw new Error(`Endpoint with id ${endpointId} doesn't exist`); + + this.addAttribute(0, "DeviceList", ArrayT(ObjectT({ + type: Field(0, UnsignedIntT), + revision: Field(1, UnsignedIntT), + })), [{ + type: endpoint.device.code, + revision: 1, + }]); + this.addAttribute(1, "ServerList", ArrayT(UnsignedIntT), [CLUSTER_ID, ...endpoint.getClusterIds()]); + this.addAttribute(3, "ClientList", ArrayT(UnsignedIntT), []); + this.addAttribute(4, "PartsList", ArrayT(UnsignedIntT), endpointId === 0 ? allEndpoints.map(endpoint => endpoint.id).filter(endpointId => endpointId !== 0) : []); } } diff --git a/src/interaction/cluster/GeneralCommissioningCluster.ts b/src/interaction/cluster/GeneralCommissioningCluster.ts index fce487eb..fdd3b4eb 100644 --- a/src/interaction/cluster/GeneralCommissioningCluster.ts +++ b/src/interaction/cluster/GeneralCommissioningCluster.ts @@ -5,9 +5,8 @@ */ import { Cluster } from "../model/Cluster"; -import { Attribute } from "../model/Attribute"; import { Field, JsType, ObjectT, StringT, UnsignedIntT } from "../../codec/TlvObjectCodec"; -import { Command, NoArgumentsT } from "../model/Command"; +import { NoArgumentsT } from "../model/Command"; const enum RegulatoryLocationType { Indoor = 0, @@ -50,25 +49,27 @@ type SuccessFailureReponse = JsType; const SuccessResponse = {errorCode: CommissioningError.Ok, debugText: ""}; export class GeneralCommissioningCluster extends Cluster { - private readonly attributes = { - breadcrumb: new Attribute(0, "Breadcrumb", UnsignedIntT, 0), - comminssioningInfo: new Attribute(1, "BasicCommissioningInfo", BasicCommissioningInfoT, {failSafeExpiryLengthSeconds: 60 /* 1mn */}), - regulatoryConfig: new Attribute(2, "RegulatoryConfig", UnsignedIntT, RegulatoryLocationType.Indoor), - locationCapability: new Attribute(3, "LocationCapability", UnsignedIntT, RegulatoryLocationType.IndoorOutdoor), - } + static Builder = () => (endpointId: number) => new GeneralCommissioningCluster(endpointId); + + private readonly attributes; - constructor() { + constructor(endpointId: number) { super( + endpointId, 0x30, "General Commissioning", - [ - new Command(0, 1, "ArmFailSafe", ArmFailSafeRequestT, SuccessFailureReponseT, request => this.handleArmFailSafeRequest(request)), - new Command(2, 3, "SetRegulatoryConfig", SetRegulatoryConfigRequestT, SuccessFailureReponseT, request => this.setRegulatoryConfig(request)), - new Command(4, 5, "CommissioningComplete", NoArgumentsT, SuccessFailureReponseT, () => this.handleCommissioningComplete()), - ], ); - this.addAttributes(Object.values(this.attributes)); + this.addCommand(0, 1, "ArmFailSafe", ArmFailSafeRequestT, SuccessFailureReponseT, request => this.handleArmFailSafeRequest(request)); + this.addCommand(2, 3, "SetRegulatoryConfig", SetRegulatoryConfigRequestT, SuccessFailureReponseT, request => this.setRegulatoryConfig(request)); + this.addCommand(4, 5, "CommissioningComplete", NoArgumentsT, SuccessFailureReponseT, () => this.handleCommissioningComplete()); + + this.attributes = { + breadcrumb: this.addAttribute(0, "Breadcrumb", UnsignedIntT, 0), + comminssioningInfo: this.addAttribute(1, "BasicCommissioningInfo", BasicCommissioningInfoT, {failSafeExpiryLengthSeconds: 60 /* 1mn */}), + regulatoryConfig: this.addAttribute(2, "RegulatoryConfig", UnsignedIntT, RegulatoryLocationType.Indoor), + locationCapability: this.addAttribute(3, "LocationCapability", UnsignedIntT, RegulatoryLocationType.IndoorOutdoor), + }; } private handleArmFailSafeRequest({breadcrumb}: ArmFailSafeRequest): SuccessFailureReponse { diff --git a/src/interaction/cluster/OnOffCluster.ts b/src/interaction/cluster/OnOffCluster.ts index 9be87018..a6047c5f 100644 --- a/src/interaction/cluster/OnOffCluster.ts +++ b/src/interaction/cluster/OnOffCluster.ts @@ -5,29 +5,32 @@ */ import { Cluster } from "../model/Cluster"; -import { Attribute } from "../model/Attribute"; import { BooleanT } from "../../codec/TlvObjectCodec"; -import { Command, NoArgumentsT, NoResponseT } from "../model/Command"; +import { NoArgumentsT, NoResponseT } from "../model/Command"; + +const CLUSTER_ID = 0x06; export class OnOffCluster extends Cluster { - private onOffAttribute = new Attribute(0, "OnOff", BooleanT, false); + static Builder = (onCallback: (() => void) | undefined, offCallback: (() => void) | undefined) => (endpointId: number) => new OnOffCluster(endpointId, onCallback, offCallback); + + private readonly onOffAttribute; constructor( - private readonly onCallback: (() => void) | undefined = undefined, - private readonly offCallback: (() => void) | undefined = undefined, + endpointId: number, + private readonly onCallback: (() => void) | undefined, + private readonly offCallback: (() => void) | undefined, ) { super( + endpointId, 0x06, "On/Off", - [ - new Command(0, 0, "Off", NoArgumentsT, NoResponseT, () => this.setOnOff(false)), - new Command(1, 1, "On", NoArgumentsT, NoResponseT, () => this.setOnOff(true)), - new Command(2, 2, "Toggle", NoArgumentsT, NoResponseT, () => this.setOnOff(!this.onOffAttribute.get())), - ], ); - this.addAttributes([ - this.onOffAttribute, - ]); + + this.addCommand(0, 0, "Off", NoArgumentsT, NoResponseT, () => this.setOnOff(false)), + this.addCommand(1, 1, "On", NoArgumentsT, NoResponseT, () => this.setOnOff(true)), + this.addCommand(2, 2, "Toggle", NoArgumentsT, NoResponseT, () => this.setOnOff(!this.onOffAttribute.get())), + + this.onOffAttribute = this.addAttribute(0, "OnOff", BooleanT, false); } private setOnOff(value: boolean) { diff --git a/src/interaction/cluster/OperationalCredentialsCluster.ts b/src/interaction/cluster/OperationalCredentialsCluster.ts index 3c2516f3..c5881543 100644 --- a/src/interaction/cluster/OperationalCredentialsCluster.ts +++ b/src/interaction/cluster/OperationalCredentialsCluster.ts @@ -8,7 +8,7 @@ import { TlvObjectCodec } from "../../codec/TlvObjectCodec"; import { Crypto } from "../../crypto/Crypto"; import { FabricBuilder } from "../../fabric/Fabric"; import { Cluster } from "../model/Cluster"; -import { Command, NoResponseT } from "../model/Command"; +import { NoResponseT } from "../model/Command"; import { Session } from "../../session/Session"; import { AddNocRequestT, AddTrustedRootCertificateRequestT, AttestationResponseT, AttestationT, CertificateChainRequestT, CertificateChainResponseT, CertificateSigningRequestT, CertificateType, CsrResponseT, RequestWithNonceT, Status, StatusResponseT } from "./OperationalCredentialsMessages"; @@ -20,21 +20,22 @@ interface OperationalCredentialsClusterConf { } export class OperationalCredentialsCluster extends Cluster { + static Builder = (conf: OperationalCredentialsClusterConf) => (endpointId: number) => new OperationalCredentialsCluster(endpointId, conf); + private fabricBuilder?: FabricBuilder; - constructor(private readonly conf: OperationalCredentialsClusterConf) { + constructor(endpointId: number, private readonly conf: OperationalCredentialsClusterConf) { super( + endpointId, 0x3e, "Operational Credentials", - [ - new Command(0, 1, "AttestationRequest", RequestWithNonceT, AttestationResponseT, ({nonce}, session) => this.handleAttestationRequest(nonce, session)), - new Command(2, 3, "CertificateChainRequest", CertificateChainRequestT, CertificateChainResponseT, ({type}) => this.handleCertificateChainRequest(type)), - new Command(4, 5, "CSRRequest", RequestWithNonceT, CsrResponseT, ({nonce}, session) => this.handleCertificateSignRequest(nonce, session)), - new Command(6, 8, "AddNOC", AddNocRequestT, StatusResponseT, ({nocCert, icaCert, ipkValue, caseAdminNode, adminVendorId}, session) => this.addNewOperationalCertificates(nocCert, icaCert, ipkValue, caseAdminNode, adminVendorId, session)), - new Command(11, 11, "AddTrustedRootCertificate", AddTrustedRootCertificateRequestT, NoResponseT, ({certificate}) => this.addTrustedRootCertificate(certificate)), - ], - [], ); + + this.addCommand(0, 1, "AttestationRequest", RequestWithNonceT, AttestationResponseT, ({nonce}, session) => this.handleAttestationRequest(nonce, session)); + this.addCommand(2, 3, "CertificateChainRequest", CertificateChainRequestT, CertificateChainResponseT, ({type}) => this.handleCertificateChainRequest(type)); + this.addCommand(4, 5, "CSRRequest", RequestWithNonceT, CsrResponseT, ({nonce}, session) => this.handleCertificateSignRequest(nonce, session)); + this.addCommand(6, 8, "AddNOC", AddNocRequestT, StatusResponseT, ({nocCert, icaCert, ipkValue, caseAdminNode, adminVendorId}, session) => this.addNewOperationalCertificates(nocCert, icaCert, ipkValue, caseAdminNode, adminVendorId, session)); + this.addCommand(11, 11, "AddTrustedRootCertificate", AddTrustedRootCertificateRequestT, NoResponseT, ({certificate}) => this.addTrustedRootCertificate(certificate)); } private handleAttestationRequest(nonce: Buffer, session: Session) { diff --git a/src/interaction/model/Attribute.ts b/src/interaction/model/Attribute.ts index b1d9873b..5d2db051 100644 --- a/src/interaction/model/Attribute.ts +++ b/src/interaction/model/Attribute.ts @@ -12,6 +12,8 @@ export class Attribute { private template: Template; constructor( + readonly endpointId: number, + readonly clusterId: number, readonly id: number, readonly name: string, template: Template, diff --git a/src/interaction/model/Cluster.ts b/src/interaction/model/Cluster.ts index 2a9bf947..3fdff46a 100644 --- a/src/interaction/model/Cluster.ts +++ b/src/interaction/model/Cluster.ts @@ -5,6 +5,7 @@ */ import { Element } from "../../codec/TlvCodec"; +import { Template } from "../../codec/TlvObjectCodec"; import { Session } from "../../session/Session"; import { Attribute } from "./Attribute"; import { Command } from "./Command"; @@ -14,21 +15,21 @@ export class Cluster { private readonly commandsMap = new Map>(); constructor( + readonly endpointId: number, readonly id: number, readonly name: string, - commands?: Command[], - attributes?: Attribute[], - ) { - if (commands !== undefined) this.addCommands(commands); - if (attributes !== undefined) this.addAttributes(attributes); - } + ) {} - addCommands(commands: Command[]) { - commands.forEach(command => this.commandsMap.set(command.invokeId, command)); + addAttribute(id: number, name: string, template: Template, defaultValue: T) { + const attribute = new Attribute(this.endpointId, this.id, id, name, template, defaultValue); + this.attributesMap.set(id, attribute); + return attribute; } - addAttributes(attributes: Attribute[]) { - attributes.forEach(attribute => this.attributesMap.set(attribute.id, attribute)); + addCommand(invokeId: number, responseId: number, name: string, requestTemplate: Template, responseTemplate: Template, handler: (request: RequestT, session: Session) => Promise | ResponseT) { + const command = new Command(invokeId, responseId, this.name, requestTemplate, responseTemplate, handler); + this.commandsMap.set(invokeId, command); + return command; } getAttributeValue(attributeId?: number) { diff --git a/src/interaction/model/Endpoint.ts b/src/interaction/model/Endpoint.ts index 7adef580..426e6ba2 100644 --- a/src/interaction/model/Endpoint.ts +++ b/src/interaction/model/Endpoint.ts @@ -15,13 +15,16 @@ export class Endpoint { constructor( readonly id: number, readonly device: {name: string, code: number}, - clusters: Cluster[], + clusterBuilders: ((endpointId: number) => Cluster)[], ) { - clusters.forEach(cluster => this.clustersMap.set(cluster.id, cluster)); + clusterBuilders.forEach(clusterBuilder => { + const cluster = clusterBuilder(id); + this.clustersMap.set(cluster.id, cluster); + }); } addDescriptorCluster(endpoints: Endpoint[]) { - const descriptorCluster = new DescriptorCluster(this, endpoints); + const descriptorCluster = DescriptorCluster.Builder(endpoints)(this.id); this.clustersMap.set(descriptorCluster.id, descriptorCluster); } From 550d20b257a4a631217d24e296ce597a35a897ca Mon Sep 17 00:00:00 2001 From: Marco Fucci di Napoli Date: Fri, 9 Sep 2022 22:20:09 -1000 Subject: [PATCH 2/4] Refactoring part 2 --- src/interaction/InteractionProtocol.ts | 6 +++--- src/interaction/model/Attribute.ts | 5 +++++ src/interaction/model/Cluster.ts | 17 +++++++++-------- src/interaction/model/Device.ts | 21 ++++++++++----------- src/interaction/model/Endpoint.ts | 19 ++++++++++--------- 5 files changed, 37 insertions(+), 31 deletions(-) diff --git a/src/interaction/InteractionProtocol.ts b/src/interaction/InteractionProtocol.ts index 75460c1d..441dbfbd 100644 --- a/src/interaction/InteractionProtocol.ts +++ b/src/interaction/InteractionProtocol.ts @@ -24,13 +24,13 @@ export class InteractionProtocol implements ProtocolHandler { await messenger.handleRequest(); } - handleReadRequest(exchange: MessageExchange, {attributes}: ReadRequest): ReadResponse { - console.log(`Received read request from ${exchange.channel.getName()}: ${attributes.map(({endpointId = "*", clusterId = "*", attributeId = "*"}) => `${endpointId}/${clusterId}/${attributeId}`).join(", ")}`); + handleReadRequest(exchange: MessageExchange, {attributes: attributePaths}: ReadRequest): ReadResponse { + console.log(`Received read request from ${exchange.channel.getName()}: ${attributePaths.map(({endpointId = "*", clusterId = "*", attributeId = "*"}) => `${endpointId}/${clusterId}/${attributeId}`).join(", ")}`); return { isFabricFiltered: true, interactionModelRevision: 1, - values: attributes.flatMap(path => this.device.getAttributeValues(path)).map(value => ({value})), + values: attributePaths.flatMap(path => this.device.getAttributes(path)).map(attribute => ({ value: attribute.getValue() })), }; } diff --git a/src/interaction/model/Attribute.ts b/src/interaction/model/Attribute.ts index 5d2db051..d7df6c5d 100644 --- a/src/interaction/model/Attribute.ts +++ b/src/interaction/model/Attribute.ts @@ -34,6 +34,11 @@ export class Attribute { getValue() { return { + path: { + endpointId: this.endpointId, + clusterId: this.clusterId, + attributeId: this.id, + }, version: this.version, value: TlvObjectCodec.encodeElement(this.value, this.template), } diff --git a/src/interaction/model/Cluster.ts b/src/interaction/model/Cluster.ts index 3fdff46a..a9f661fd 100644 --- a/src/interaction/model/Cluster.ts +++ b/src/interaction/model/Cluster.ts @@ -31,15 +31,16 @@ export class Cluster { this.commandsMap.set(invokeId, command); return command; } - - getAttributeValue(attributeId?: number) { - // If attributeId is not provided, iterate over all attributes - var attributeIds = (attributeId === undefined) ? [...this.attributesMap.keys()] : [ attributeId ]; - return attributeIds.flatMap(attributeId => { - const valueVersion = this.attributesMap.get(attributeId)?.getValue(); - return (valueVersion === undefined) ? [] : [{attributeId, ...valueVersion}]; - }) + getAttributes(attributeId?: number): Attribute[] { + if (attributeId === undefined) { + // If the attributeId is not provided, return all attributes + return [...this.attributesMap.values()]; + } + + const attribute = this.attributesMap.get(attributeId); + if (attribute === undefined) return []; + return [attribute]; } async invoke(session: Session, commandId: number, args: Element) { diff --git a/src/interaction/model/Device.ts b/src/interaction/model/Device.ts index 42d50cff..38233dbd 100644 --- a/src/interaction/model/Device.ts +++ b/src/interaction/model/Device.ts @@ -6,6 +6,7 @@ import { Element } from "../../codec/TlvCodec"; import { Session } from "../../session/Session"; +import { Attribute } from "./Attribute"; import { Endpoint } from "./Endpoint"; interface AttributePath { @@ -30,17 +31,15 @@ export class Device { }); } - getAttributeValues({endpointId, clusterId, attributeId}: AttributePath) { - // If the endpoint is not provided, iterate over all endpoints - var endpointIds = (endpointId === undefined) ? [...this.endpointsMap.keys()] : [ endpointId ]; - return endpointIds.flatMap(endpointId => { - const values = this.endpointsMap.get(endpointId)?.getAttributeValue(clusterId, attributeId); - if (values === undefined) return []; - return values.map(({clusterId, attributeId, value, version}) => ({ - path: { endpointId, clusterId, attributeId }, - value, version - })); - }) + getAttributes({endpointId, clusterId, attributeId}: AttributePath): Attribute[] { + if (endpointId === undefined) { + // If the endpoint is not provided, iterate over all endpoints + return [...this.endpointsMap.values()].flatMap(endpoint => endpoint.getAttributes(clusterId, attributeId)); + } + + const endpoint = this.endpointsMap.get(endpointId); + if (endpoint === undefined) return []; + return endpoint.getAttributes(clusterId, attributeId); } async invoke(session: Session, commandPath: CommandPath, args: Element) { diff --git a/src/interaction/model/Endpoint.ts b/src/interaction/model/Endpoint.ts index 426e6ba2..ecf6ad33 100644 --- a/src/interaction/model/Endpoint.ts +++ b/src/interaction/model/Endpoint.ts @@ -7,6 +7,7 @@ import { Element } from "../../codec/TlvCodec"; import { Session } from "../../session/Session"; import { DescriptorCluster } from "../cluster/DescriptorCluster"; +import { Attribute } from "./Attribute"; import { Cluster } from "./Cluster"; export class Endpoint { @@ -28,15 +29,15 @@ export class Endpoint { this.clustersMap.set(descriptorCluster.id, descriptorCluster); } - getAttributeValue(clusterId?: number, attributeId?: number) { - // If clusterId is not provided, iterate over all clusters - var clusterIds = (clusterId === undefined) ? [...this.clustersMap.keys()] : [ clusterId ]; - - return clusterIds.flatMap(clusterId => { - const values = this.clustersMap.get(clusterId)?.getAttributeValue(attributeId); - if (values === undefined) return []; - return values.map(value => ({clusterId, ...value})); - }) + getAttributes(clusterId?: number, attributeId?: number): Attribute[] { + if (clusterId === undefined) { + // If the clusterId is not provided, iterate over all clusters + return [...this.clustersMap.values()].flatMap(cluster => cluster.getAttributes(attributeId)); + } + + const cluster = this.clustersMap.get(clusterId); + if (cluster === undefined) return []; + return cluster.getAttributes(attributeId); } getClusterIds() { From 9013978dd9e2bb0165829b06005e9a9c8aabddb2 Mon Sep 17 00:00:00 2001 From: Marco Fucci di Napoli Date: Mon, 12 Sep 2022 12:54:23 -1000 Subject: [PATCH 3/4] Implement attribute subscription --- src/interaction/InteractionMessages.ts | 9 ++- src/interaction/InteractionMessenger.ts | 71 +++++++++++------ src/interaction/InteractionProtocol.ts | 72 +++++++++++++---- src/interaction/model/Attribute.ts | 32 +++++++- src/server/MatterServer.ts | 32 ++++++-- src/server/MessageExchange.ts | 89 ++++++++++++++-------- src/session/SecureSession.ts | 47 ++++++++++-- src/session/Session.ts | 5 ++ src/session/UnsecureSession.ts | 21 +++++ src/session/secure/SecureChannelHandler.ts | 11 ++- 10 files changed, 299 insertions(+), 90 deletions(-) diff --git a/src/interaction/InteractionMessages.ts b/src/interaction/InteractionMessages.ts index 1bbd130e..bbb1e222 100644 --- a/src/interaction/InteractionMessages.ts +++ b/src/interaction/InteractionMessages.ts @@ -7,6 +7,10 @@ import { TlvType } from "../codec/TlvCodec"; import { AnyT, ArrayT, BooleanT, Field, ObjectT, OptionalField, UnsignedIntT, UnsignedLongT } from "../codec/TlvObjectCodec"; +export const StatusReport = ObjectT({ + status: OptionalField(0, UnsignedIntT), +}); + const AttributePathT = ObjectT({ endpointId: OptionalField(2, UnsignedIntT), clusterId: OptionalField(3, UnsignedIntT), @@ -19,7 +23,8 @@ export const ReadRequestT = ObjectT({ interactionModelRevision: Field(0xFF, UnsignedIntT), }); -export const ReadResponseT = ObjectT({ +export const DataReportT = ObjectT({ + subscriptionId: OptionalField(0, UnsignedIntT), values: Field(1, ArrayT(ObjectT({ value: Field(1, ObjectT({ version: Field(0, UnsignedIntT), @@ -64,7 +69,7 @@ export const SubscribeRequestT = ObjectT({ export const SubscribeResponseT = ObjectT({ subscriptionId: Field(0, UnsignedIntT), - minIntervalFloorSeconds: Field(1, UnsignedIntT), + minIntervalFloorSeconds: OptionalField(1, UnsignedIntT), maxIntervalCeilingSeconds: Field(2, UnsignedIntT), }); diff --git a/src/interaction/InteractionMessenger.ts b/src/interaction/InteractionMessenger.ts index ba80125c..bf851401 100644 --- a/src/interaction/InteractionMessenger.ts +++ b/src/interaction/InteractionMessenger.ts @@ -6,7 +6,13 @@ import { JsType, TlvObjectCodec } from "../codec/TlvObjectCodec"; import { MessageExchange } from "../server/MessageExchange"; -import { InvokeRequestT, InvokeResponseT, ReadRequestT, ReadResponseT, SubscribeRequestT, SubscribeResponseT } from "./InteractionMessages"; +import { StatusResponseT } from "./cluster/OperationalCredentialsMessages"; +import { InvokeRequestT, InvokeResponseT, ReadRequestT, DataReportT, SubscribeRequestT, SubscribeResponseT } from "./InteractionMessages"; + +export const enum Status { + Success = 0x00, + Failure = 0x01, +} export const enum MessageType { StatusResponse = 0x01, @@ -22,7 +28,7 @@ export const enum MessageType { } export type ReadRequest = JsType; -export type ReadResponse = JsType; +export type DataReport = JsType; export type SubscribeRequest = JsType; export type SubscribeResponse = JsType; export type InvokeRequest = JsType; @@ -32,31 +38,48 @@ export class InteractionMessenger { constructor( private readonly exchange: MessageExchange, - private readonly handleReadRequest: (request: ReadRequest) => ReadResponse, - private readonly handleSubscribeRequest: (request: SubscribeRequest) => SubscribeResponse, - private readonly handleInvokeRequest: (request: InvokeRequest) => Promise, ) {} - async handleRequest() { + async handleRequest( + handleReadRequest: (request: ReadRequest) => DataReport, + handleSubscribeRequest: (request: SubscribeRequest) => SubscribeResponse | undefined, + handleInvokeRequest: (request: InvokeRequest) => Promise, + ) { const message = await this.exchange.nextMessage(); - switch (message.payloadHeader.messageType) { - case MessageType.ReadRequest: - const readRequest = TlvObjectCodec.decode(message.payload, ReadRequestT); - const readResponse = this.handleReadRequest(readRequest); - this.exchange.send(MessageType.ReportData, TlvObjectCodec.encode(readResponse, ReadResponseT)); - break; - case MessageType.SubscribeRequest: - const subscribeRequest = TlvObjectCodec.decode(message.payload, SubscribeRequestT); - const subscribeResponse = this.handleSubscribeRequest(subscribeRequest); - this.exchange.send(MessageType.SubscribeResponse, TlvObjectCodec.encode(subscribeResponse, SubscribeResponseT)); - break; - case MessageType.InvokeCommandRequest: - const invokeRequest = TlvObjectCodec.decode(message.payload, InvokeRequestT); - const invokeResponse = await this.handleInvokeRequest(invokeRequest); - this.exchange.send(MessageType.InvokeCommandResponse, TlvObjectCodec.encode(invokeResponse, InvokeResponseT)); - break; - default: - throw new Error(`Unsupported message type ${message.payloadHeader.messageType}`); + try { + switch (message.payloadHeader.messageType) { + case MessageType.ReadRequest: + const readRequest = TlvObjectCodec.decode(message.payload, ReadRequestT); + this.sendDataReport(handleReadRequest(readRequest)); + break; + case MessageType.SubscribeRequest: + const subscribeRequest = TlvObjectCodec.decode(message.payload, SubscribeRequestT); + const subscribeResponse = handleSubscribeRequest(subscribeRequest); + if (subscribeRequest === undefined) { + this.sendStatus(Status.Success); + } else { + this.exchange.send(MessageType.SubscribeResponse, TlvObjectCodec.encode(subscribeResponse, SubscribeResponseT)); + } + break; + case MessageType.InvokeCommandRequest: + const invokeRequest = TlvObjectCodec.decode(message.payload, InvokeRequestT); + const invokeResponse = await handleInvokeRequest(invokeRequest); + this.exchange.send(MessageType.InvokeCommandResponse, TlvObjectCodec.encode(invokeResponse, InvokeResponseT)); + break; + default: + throw new Error(`Unsupported message type ${message.payloadHeader.messageType}`); + } + } catch (error) { + console.error(error); + this.sendStatus(Status.Failure); } } + + sendDataReport(dataReport: DataReport) { + this.exchange.send(MessageType.ReportData, TlvObjectCodec.encode(dataReport, DataReportT)); + } + + private sendStatus(status: Status) { + this.exchange.send(MessageType.StatusResponse, TlvObjectCodec.encode({status}, StatusResponseT)); + } } diff --git a/src/interaction/InteractionProtocol.ts b/src/interaction/InteractionProtocol.ts index 441dbfbd..691c8329 100644 --- a/src/interaction/InteractionProtocol.ts +++ b/src/interaction/InteractionProtocol.ts @@ -5,9 +5,12 @@ */ import { Device } from "./model/Device"; -import { ProtocolHandler } from "../server/MatterServer"; +import { ExchangeSocket, MatterServer, Protocol, ProtocolHandler } from "../server/MatterServer"; import { MessageExchange } from "../server/MessageExchange"; -import { InteractionMessenger, InvokeRequest, InvokeResponse, ReadRequest, ReadResponse, SubscribeRequest, SubscribeResponse } from "./InteractionMessenger"; +import { InteractionMessenger, InvokeRequest, InvokeResponse, ReadRequest, DataReport, SubscribeRequest, SubscribeResponse } from "./InteractionMessenger"; +import { SecureSession } from "../session/SecureSession"; +import { Attribute, Report } from "./model/Attribute"; +import { Session } from "../session/Session"; export class InteractionProtocol implements ProtocolHandler { constructor( @@ -15,16 +18,14 @@ export class InteractionProtocol implements ProtocolHandler { ) {} async onNewExchange(exchange: MessageExchange) { - const messenger = new InteractionMessenger( - exchange, + await new InteractionMessenger(exchange).handleRequest( readRequest => this.handleReadRequest(exchange, readRequest), subscribeRequest => this.handleSubscribeRequest(exchange, subscribeRequest), invokeRequest => this.handleInvokeRequest(exchange, invokeRequest), ); - await messenger.handleRequest(); } - handleReadRequest(exchange: MessageExchange, {attributes: attributePaths}: ReadRequest): ReadResponse { + handleReadRequest(exchange: MessageExchange, {attributes: attributePaths}: ReadRequest): DataReport { console.log(`Received read request from ${exchange.channel.getName()}: ${attributePaths.map(({endpointId = "*", clusterId = "*", attributeId = "*"}) => `${endpointId}/${clusterId}/${attributeId}`).join(", ")}`); return { @@ -34,16 +35,28 @@ export class InteractionProtocol implements ProtocolHandler { }; } - handleSubscribeRequest(exchange: MessageExchange, { minIntervalFloorSeconds, maxIntervalCeilingSeconds }: SubscribeRequest): SubscribeResponse { + handleSubscribeRequest(exchange: MessageExchange, { minIntervalFloorSeconds, maxIntervalCeilingSeconds, attributeRequests, keepSubscriptions }: SubscribeRequest): SubscribeResponse | undefined { console.log(`Received subscribe request from ${exchange.channel.getName()}`); - // TODO: implement this + if (!exchange.session.isSecure()) throw new Error("Subscriptions are only implemented on secure sessions"); - return { - subscriptionId: 0, - minIntervalFloorSeconds, - maxIntervalCeilingSeconds, - }; + const session = exchange.session as SecureSession; + + if (!keepSubscriptions) { + session.clearSubscriptions(); + } + + if (attributeRequests !== undefined) { + const attributes = attributeRequests.flatMap(path => this.device.getAttributes(path)); + + if (attributeRequests.length === 0) throw new Error("Invalid subscription request"); + + return { + subscriptionId: session.addSubscription(SubscriptionHandler.Builder(session, exchange.channel.channel, session.getServer(), attributes)), + minIntervalFloorSeconds, + maxIntervalCeilingSeconds, + }; + } } async handleInvokeRequest(exchange: MessageExchange, {invokes}: InvokeRequest): Promise { @@ -63,3 +76,36 @@ export class InteractionProtocol implements ProtocolHandler { }; } } + +export class SubscriptionHandler { + + static Builder = (session: Session, channel: ExchangeSocket, server: MatterServer, attributes: Attribute[]) => (subscriptionId: number) => new SubscriptionHandler(subscriptionId, session, channel, server, attributes); + + constructor( + readonly subscriptionId: number, + private readonly session: Session, + private readonly channel: ExchangeSocket, + private readonly server: MatterServer, + private readonly attributes: Attribute[], + ) { + // TODO: implement minIntervalFloorSeconds and maxIntervalCeilingSeconds + + attributes.forEach(attribute => attribute.addSubscription(this)); + } + + sendReport(report: Report) { + // TODO: this should be sent to the last discovered address of this node instead of the one used to request the subscription + + const exchange = this.server.initiateExchange(this.session, this.channel, Protocol.INTERACTION_MODEL); + new InteractionMessenger(exchange).sendDataReport({ + subscriptionId: this.subscriptionId, + isFabricFiltered: true, + interactionModelRevision: 1, + values: [{ value: report }], + }); + } + + cancel() { + this.attributes.forEach(attribute => attribute.removeSubscription(this.subscriptionId)); + } +} diff --git a/src/interaction/model/Attribute.ts b/src/interaction/model/Attribute.ts index d7df6c5d..be829ba3 100644 --- a/src/interaction/model/Attribute.ts +++ b/src/interaction/model/Attribute.ts @@ -4,35 +4,51 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { Element } from "../../codec/TlvCodec"; import { Template, TlvObjectCodec } from "../../codec/TlvObjectCodec"; +import { SubscriptionHandler } from "../InteractionProtocol"; + +export interface Report { + path: { + endpointId: number, + clusterId: number, + attributeId: number, + }, + version: number, + value: Element | undefined, +} export class Attribute { private value: T; private version = 0; - private template: Template; + private readonly subscriptionsMap = new Map(); constructor( readonly endpointId: number, readonly clusterId: number, readonly id: number, readonly name: string, - template: Template, + private readonly template: Template, defaultValue: T, ) { this.value = defaultValue; - this.template = template; } set(value: T) { this.version++; this.value = value; + + if (this.subscriptionsMap.size !== 0) { + const report = this.getValue(); + [...this.subscriptionsMap.values()].forEach(subscription => subscription.sendReport(report)) + } } get(): T { return this.value; } - getValue() { + getValue(): Report { return { path: { endpointId: this.endpointId, @@ -43,4 +59,12 @@ export class Attribute { value: TlvObjectCodec.encodeElement(this.value, this.template), } } + + addSubscription(subscription: SubscriptionHandler) { + this.subscriptionsMap.set(subscription.subscriptionId, subscription); + } + + removeSubscription(subscriptionId: number) { + this.subscriptionsMap.delete(subscriptionId); + } } diff --git a/src/server/MatterServer.ts b/src/server/MatterServer.ts index 9fb722bc..9be7a09d 100644 --- a/src/server/MatterServer.ts +++ b/src/server/MatterServer.ts @@ -5,12 +5,12 @@ */ import { Crypto } from "../crypto/Crypto"; -import { MessageCodec, SessionType } from "../codec/MessageCodec"; +import { Message, MessageCodec, SessionType } from "../codec/MessageCodec"; import { SessionManager } from "../session/SessionManager"; import { MessageExchange } from "./MessageExchange"; -import { MdnsServer } from "../mdns/MdnsServer"; import { FabricManager } from "../fabric/FabricManager"; import { MatterMdnsServer } from "../mdns/MatterMdnsServer"; +import { Session } from "../session/Session"; export const enum Protocol { SECURE_CHANNEL = 0x0000, @@ -27,7 +27,7 @@ export interface Channel { } export interface ProtocolHandler { - onNewExchange(exchange: MessageExchange): void; + onNewExchange(exchange: MessageExchange, message: Message): void; } export class MatterServer { @@ -43,7 +43,7 @@ export class MatterServer { private readonly channels = new Array(); private readonly protocolHandlers = new Map(); - + private readonly exchangeCounter = new ExchangeCounter(); private readonly messageCounter = new MessageCounter(); private readonly exchanges = new Map(); private readonly sessionManager = new SessionManager(this); @@ -76,6 +76,14 @@ export class MatterServer { return this.fabricManager; } + initiateExchange(session: Session, channel: ExchangeSocket, protocolId: number) { + const exchangeId = this.exchangeCounter.getIncrementedCounter(); + const exchange = MessageExchange.initiate(session, channel, exchangeId, protocolId, this.messageCounter, () => this.exchanges.delete(exchangeId & 0x10000)); + // Ensure exchangeIds are not colliding in the Map by adding 1 in front of exchanges initiated by this device. + this.exchanges.set(exchangeId & 0x10000, exchange); + return exchange; + } + private onMessage(socket: ExchangeSocket, messageBytes: Buffer) { var packet = MessageCodec.decodePacket(messageBytes); if (packet.header.sessionType === SessionType.Group) throw new Error("Group messages are not supported"); @@ -84,7 +92,7 @@ export class MatterServer { if (session === undefined) throw new Error(`Cannot find a session for ID ${packet.header.sessionId}`); const message = session.decode(packet); - const exchangeId = message.payloadHeader.exchangeId; + const exchangeId = message.payloadHeader.isInitiatorMessage ? message.payloadHeader.exchangeId : message.payloadHeader.exchangeId & 0x10000; if (this.exchanges.has(exchangeId)) { const exchange = this.exchanges.get(exchangeId); exchange?.onMessageReceived(message); @@ -93,8 +101,20 @@ export class MatterServer { this.exchanges.set(exchangeId, exchange); const protocolHandler = this.protocolHandlers.get(message.payloadHeader.protocolId); if (protocolHandler === undefined) throw new Error(`Unsupported protocol ${message.payloadHeader.protocolId}`); - protocolHandler.onNewExchange(exchange); + protocolHandler.onNewExchange(exchange, message); + } + } +} + +class ExchangeCounter { + private exchangeCounter = Crypto.getRandomUInt16(); + + getIncrementedCounter() { + this.exchangeCounter++; + if (this.exchangeCounter > 0xFFFF) { + this.exchangeCounter = 0; } + return this.exchangeCounter; } } diff --git a/src/server/MessageExchange.ts b/src/server/MessageExchange.ts index 359678e5..28cddf8c 100644 --- a/src/server/MessageExchange.ts +++ b/src/server/MessageExchange.ts @@ -1,4 +1,4 @@ -import { Message, MessageCodec } from "../codec/MessageCodec"; +import { Message, MessageCodec, SessionType } from "../codec/MessageCodec"; import { Queue } from "../util/Queue"; import { Session } from "../session/Session"; import { ExchangeSocket, MessageCounter } from "./MatterServer"; @@ -24,44 +24,76 @@ class MessageChannel implements ExchangeSocket { export class MessageExchange { private readonly messageCodec = new MessageCodec(); readonly channel: MessageChannel; + private readonly activeRetransmissionTimeoutMs: number; + private readonly retransmissionRetries: number; + private readonly messagesQueue = new Queue(); + private receivedMessageToAck: Message | undefined; private sentMessageToAck: Message | undefined; private retransmissionTimeoutId: NodeJS.Timeout | undefined; - private activeRetransmissionTimeoutMs: number; - private retransmissionRetries: number; - private receivedMessageToAck: Message | undefined; - private messagesQueue = new Queue(); - constructor( - readonly session: Session, + static fromInitialMessage( + session: Session, channel: ExchangeSocket, - private readonly messageCounter: MessageCounter, - private readonly initialMessage: Message, - private readonly closeCallback: () => void, + messageCounter: MessageCounter, + initialMessage: Message, + closeCallback: () => void, ) { - this.channel = new MessageChannel(channel, session); - this.receivedMessageToAck = initialMessage.payloadHeader.requiresAck ? initialMessage : undefined; - this.messagesQueue.write(initialMessage); - const {activeRetransmissionTimeoutMs: activeRetransmissionTimeoutMs, retransmissionRetries} = session.getMrpParameters(); - this.activeRetransmissionTimeoutMs = activeRetransmissionTimeoutMs; - this.retransmissionRetries = retransmissionRetries; + const exchange = new MessageExchange( + session, + channel, + messageCounter, + false, + session.getId(), + initialMessage.packetHeader.destNodeId, + initialMessage.packetHeader.sourceNodeId, + initialMessage.payloadHeader.exchangeId, + initialMessage.payloadHeader.protocolId, + closeCallback, + ) + exchange.onMessageReceived(initialMessage); + return exchange; } - static fromInitialMessage( + static initiate( session: Session, channel: ExchangeSocket, + exchangeId: number, + protocolId: number, messageCounter: MessageCounter, - initialMessage: Message, closeCallback: () => void, ) { return new MessageExchange( session, channel, messageCounter, - initialMessage, + true, + session.getPeerSessionId(), + session.getNodeId(), + session.getPeerNodeId(), + exchangeId, + protocolId, closeCallback, ); } + constructor( + readonly session: Session, + channel: ExchangeSocket, + private readonly messageCounter: MessageCounter, + private readonly isInitiator: boolean, + private readonly sessionId: number, + private readonly nodeId: bigint | undefined, + private readonly peerNodeId: bigint | undefined, + private readonly exchangeId: number, + private readonly protocolId: number, + private readonly closeCallback: () => void, + ) { + this.channel = new MessageChannel(channel, session); + const {activeRetransmissionTimeoutMs: activeRetransmissionTimeoutMs, retransmissionRetries} = session.getMrpParameters(); + this.activeRetransmissionTimeoutMs = activeRetransmissionTimeoutMs; + this.retransmissionRetries = retransmissionRetries; + } + onMessageReceived(message: Message) { const { packetHeader: { messageId }, payloadHeader: { requiresAck, ackedMessageId, messageType } } = message; @@ -101,20 +133,19 @@ export class MessageExchange { send(messageType: number, payload: Buffer) { if (this.sentMessageToAck !== undefined) throw new Error("The previous message has not been acked yet, cannot send a new message"); - const { packetHeader: { sessionId, sessionType, destNodeId, sourceNodeId }, payloadHeader: { exchangeId, protocolId } } = this.initialMessage; const message = { packetHeader: { - sessionId, - sessionType, + sessionId: this.sessionId, + sessionType: SessionType.Unicast, // TODO: support multicast messageId: this.messageCounter.getIncrementedCounter(), - destNodeId: sourceNodeId, - sourceNodeId: destNodeId, + destNodeId: this.peerNodeId, + sourceNodeId: this.nodeId, }, payloadHeader: { - exchangeId, - protocolId, + exchangeId: this.exchangeId, + protocolId: this.protocolId, messageType, - isInitiatorMessage: false, + isInitiatorMessage: this.isInitiator, requiresAck: true, ackedMessageId: this.receivedMessageToAck?.packetHeader.messageId, }, @@ -131,10 +162,6 @@ export class MessageExchange { return this.messagesQueue.read(); } - getInitialMessageType() { - return this.initialMessage.payloadHeader.messageType; - } - async waitFor(messageType: number) { const message = await this.messagesQueue.read(); const { payloadHeader: { messageType: receivedMessageType } } = message; diff --git a/src/session/SecureSession.ts b/src/session/SecureSession.ts index b9901e82..64e68606 100644 --- a/src/session/SecureSession.ts +++ b/src/session/SecureSession.ts @@ -7,6 +7,7 @@ import { Message, MessageCodec, Packet } from "../codec/MessageCodec"; import { Crypto } from "../crypto/Crypto"; import { Fabric } from "../fabric/Fabric"; +import { SubscriptionHandler } from "../interaction/InteractionProtocol"; import { MatterServer } from "../server/MatterServer"; import { LEBufferWriter } from "../util/LEBufferWriter"; import { DEFAULT_ACTIVE_RETRANSMISSION_TIMEOUT_MS, DEFAULT_IDLE_RETRANSMISSION_TIMEOUT_MS, DEFAULT_RETRANSMISSION_RETRIES, Session } from "./Session"; @@ -15,6 +16,16 @@ const SESSION_KEYS_INFO = Buffer.from("SessionKeys"); export class SecureSession implements Session { private fabric?: Fabric; + private nextSubscriptionId = 0; + private readonly subscriptions = new Array(); + + static async create(matterServer: MatterServer, id: number, nodeId: bigint, peerNodeId: bigint, peerSessionId: number, sharedSecret: Buffer, salt: Buffer, isInitiator: boolean, idleRetransTimeoutMs?: number, activeRetransTimeoutMs?: number) { + const keys = await Crypto.hkdf(sharedSecret, salt, SESSION_KEYS_INFO, 16 * 3); + const decryptKey = isInitiator ? keys.slice(16, 32) : keys.slice(0, 16); + const encryptKey = isInitiator ? keys.slice(0, 16) : keys.slice(16, 32); + const attestationKey = keys.slice(32, 48); + return new SecureSession(matterServer, id, nodeId, peerNodeId, peerSessionId, sharedSecret, decryptKey, encryptKey, attestationKey, idleRetransTimeoutMs, activeRetransTimeoutMs); + } constructor( private readonly matterServer: MatterServer, @@ -31,12 +42,8 @@ export class SecureSession implements Session { private readonly retransmissionRetries: number = DEFAULT_RETRANSMISSION_RETRIES, ) {} - static async create(matterServer: MatterServer, id: number, nodeId: bigint, peerNodeId: bigint, peerSessionId: number, sharedSecret: Buffer, salt: Buffer, isInitiator: boolean, idleRetransTimeoutMs?: number, activeRetransTimeoutMs?: number) { - const keys = await Crypto.hkdf(sharedSecret, salt, SESSION_KEYS_INFO, 16 * 3); - const decryptKey = isInitiator ? keys.slice(16, 32) : keys.slice(0, 16); - const encryptKey = isInitiator ? keys.slice(0, 16) : keys.slice(16, 32); - const attestationKey = keys.slice(32, 48); - return new SecureSession(matterServer, id, nodeId, peerNodeId, peerSessionId, sharedSecret, decryptKey, encryptKey, attestationKey, idleRetransTimeoutMs, activeRetransTimeoutMs); + isSecure(): boolean { + return true; } decode({ header, bytes }: Packet): Message { @@ -76,6 +83,34 @@ export class SecureSession implements Session { return this.matterServer; } + getId() { + return this.id; + } + + getPeerSessionId(): number { + return this.peerSessionId; + } + + getNodeId() { + return this.nodeId; + } + + getPeerNodeId() { + return this.peerNodeId; + } + + addSubscription(subscriptionBuilder: (subscriptionId: number) => SubscriptionHandler): number { + const subscriptionId = this.nextSubscriptionId++; + const subscription = subscriptionBuilder(subscriptionId); + this.subscriptions.push(subscription); + return subscriptionId; + } + + clearSubscriptions() { + this.subscriptions.forEach(subscription => subscription.cancel()); + this.subscriptions.length = 0; + } + private generateNonce(securityFlags: number, messageId: number, nodeId: bigint) { const buffer = new LEBufferWriter(); buffer.writeUInt8(securityFlags); diff --git a/src/session/Session.ts b/src/session/Session.ts index 36bffe3c..1406d682 100644 --- a/src/session/Session.ts +++ b/src/session/Session.ts @@ -13,6 +13,7 @@ interface MrpParameters { } export interface Session { + isSecure(): boolean; getName(): string; decode(packet: Packet): Message; encode(message: Message): Packet; @@ -20,4 +21,8 @@ export interface Session { setFabric(fabric: Fabric): void; getMrpParameters(): MrpParameters; getServer(): MatterServer; + getId(): number; + getPeerSessionId(): number; + getNodeId(): bigint | undefined; + getPeerNodeId(): bigint | undefined; } diff --git a/src/session/UnsecureSession.ts b/src/session/UnsecureSession.ts index 2f2dae1b..e77ba7b4 100644 --- a/src/session/UnsecureSession.ts +++ b/src/session/UnsecureSession.ts @@ -8,12 +8,17 @@ import { Packet, Message, MessageCodec } from "../codec/MessageCodec"; import { Fabric } from "../fabric/Fabric"; import { MatterServer } from "../server/MatterServer"; import { DEFAULT_ACTIVE_RETRANSMISSION_TIMEOUT_MS, DEFAULT_IDLE_RETRANSMISSION_TIMEOUT_MS, DEFAULT_RETRANSMISSION_RETRIES, Session } from "./Session"; +import { UNICAST_UNSECURE_SESSION_ID } from "./SessionManager"; export class UnsecureSession implements Session { constructor( private readonly matterServer: MatterServer, ) {} + isSecure(): boolean { + return false; + } + decode(packet: Packet): Message { return MessageCodec.decodePayload(packet); } @@ -45,4 +50,20 @@ export class UnsecureSession implements Session { getServer() { return this.matterServer; } + + getId(): number { + return UNICAST_UNSECURE_SESSION_ID; + } + + getPeerSessionId(): number { + return UNICAST_UNSECURE_SESSION_ID; + } + + getNodeId() { + return undefined; + } + + getPeerNodeId() { + return undefined; + } } diff --git a/src/session/secure/SecureChannelHandler.ts b/src/session/secure/SecureChannelHandler.ts index 263fc0a0..52ed5193 100644 --- a/src/session/secure/SecureChannelHandler.ts +++ b/src/session/secure/SecureChannelHandler.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { Message } from "../../codec/MessageCodec"; import { ProtocolHandler } from "../../server/MatterServer"; import { MessageExchange } from "../../server/MessageExchange"; import { CasePairing } from "./CasePairing"; @@ -17,14 +18,16 @@ export class SecureChannelHandler implements ProtocolHandler { private readonly caseCommissioner: CasePairing, ) {} - onNewExchange(exchange: MessageExchange) { - const messageType = exchange.getInitialMessageType(); + onNewExchange(exchange: MessageExchange, message: Message) { + const messageType = message.payloadHeader.messageType; switch (messageType) { case MessageType.PbkdfParamRequest: - return this.paseCommissioner.onNewExchange(exchange); + this.paseCommissioner.onNewExchange(exchange); + break; case MessageType.Sigma1: - return this.caseCommissioner.onNewExchange(exchange); + this.caseCommissioner.onNewExchange(exchange); + break; default: throw new Error(`Unexpected initial message on secure channel protocol: ${messageType.toString(16)}`); } From 02af7b2cd5821a1bf4f665bf9dfc2b59004ae55c Mon Sep 17 00:00:00 2001 From: Marco Fucci di Napoli Date: Tue, 13 Sep 2022 07:27:49 -1000 Subject: [PATCH 4/4] Fix unit tests --- test/codec/TlvObjectTest.ts | 6 +++--- test/interaction/InteractionProtocolTest.ts | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/codec/TlvObjectTest.ts b/test/codec/TlvObjectTest.ts index 05a86b27..b93507ee 100644 --- a/test/codec/TlvObjectTest.ts +++ b/test/codec/TlvObjectTest.ts @@ -7,7 +7,7 @@ import assert from "assert"; import { TlvType } from "../../src/codec/TlvCodec"; import { BooleanT, ByteStringT, JsType, ObjectT, Field, TlvObjectCodec, UnsignedIntT, OptionalField } from "../../src/codec/TlvObjectCodec"; -import { ReadResponseT } from "../../src/interaction/InteractionMessages"; +import { DataReportT } from "../../src/interaction/InteractionMessages"; import { TlvTag } from "../../src/codec/TlvTag"; @@ -82,7 +82,7 @@ describe("TlvObjectCodec", () => { }); it("decodes a structure with lists and variable types", () => { - const result = TlvObjectCodec.decode(ENCODED_ARRAY_VARIABLE, ReadResponseT); + const result = TlvObjectCodec.decode(ENCODED_ARRAY_VARIABLE, DataReportT); assert.deepEqual(result, DECODED_ARRAY_VARIABLE); }); @@ -102,7 +102,7 @@ describe("TlvObjectCodec", () => { }); it("encodes a structure with lists and variable types", () => { - const result = TlvObjectCodec.encode(DECODED_ARRAY_VARIABLE, ReadResponseT); + const result = TlvObjectCodec.encode(DECODED_ARRAY_VARIABLE, DataReportT); assert.deepEqual(result.toString("hex"), ENCODED_ARRAY_VARIABLE.toString("hex")); }); diff --git a/test/interaction/InteractionProtocolTest.ts b/test/interaction/InteractionProtocolTest.ts index aede8d1e..2cc4c119 100644 --- a/test/interaction/InteractionProtocolTest.ts +++ b/test/interaction/InteractionProtocolTest.ts @@ -9,7 +9,7 @@ import { BasicCluster } from "../../src/interaction/cluster/BasicCluster"; import { TlvType } from "../../src/codec/TlvCodec"; import { TlvTag } from "../../src/codec/TlvTag"; import { InteractionProtocol } from "../../src/interaction/InteractionProtocol"; -import { ReadRequest, ReadResponse } from "../../src/interaction/InteractionMessenger"; +import { ReadRequest, DataReport } from "../../src/interaction/InteractionMessenger"; import { Device } from "../../src/interaction/model/Device"; import { Endpoint } from "../../src/interaction/model/Endpoint"; import { MessageExchange } from "../../src/server/MessageExchange"; @@ -24,7 +24,7 @@ const READ_REQUEST: ReadRequest = { ], }; -const READ_RESPONSE: ReadResponse = { +const READ_RESPONSE: DataReport = { interactionModelRevision: 1, isFabricFiltered: true, values: [ @@ -63,7 +63,7 @@ describe("InteractionProtocol", () => { it("replies with attribute values", () => { const interactionProtocol = new InteractionProtocol(new Device([ new Endpoint(0, DEVICE.ROOT, [ - new BasicCluster({vendorName: "vendor", vendorId: 1, productName: "product", productId: 2}), + BasicCluster.Builder({ vendorName: "vendor", vendorId: 1, productName: "product", productId: 2 }), ]) ]));