// Copyright © 2023 The Browser Company
// SPDX-License-Identifier: BSD-3

import CWinRT

/// A smart pointer for COM interfaces.
///
/// `ComPtr` holds on to the underlying interface pointer; its semantics are meant to
/// mirror those of the `ComPtr` class in WRL. The design of `ComPtr` and
/// `ComPtrs.initialize` is that there should be no use of `UnsafeMutablePointer`
/// anywhere else in the code base. The only place where `UnsafeMutablePointer`
/// should be used is where it is required at the ABI boundary.
///
/// Ownership rules implemented below:
/// - `init(_:)` calls `AddRef` on the pointer it is given.
/// - `init?(takingOwnership:)` stores the pointer without calling `AddRef`.
/// - `deinit` calls `Release` unless the pointer was handed away with `detach()`.
public class ComPtr<CInterface> {
    /// The wrapped interface pointer; `nil` once `detach()` has been called.
    fileprivate var pUnk: UnsafeMutablePointer<CInterface>?

    /// Wraps `ptr` and calls `AddRef` on it.
    public init(_ ptr: UnsafeMutablePointer<CInterface>) {
        self.pUnk = ptr
        asIUnknown {
            _ = $0.pointee.lpVtbl.pointee.AddRef($0)
        }
    }

    /// Wraps `ptr` and calls `AddRef` on it, or returns `nil` when `ptr` is `nil`.
    public convenience init?(_ ptr: UnsafeMutablePointer<CInterface>?) {
        guard let ptr else { return nil }
        self.init(ptr)
    }

    /// Wraps `ptr` without calling `AddRef`, or returns `nil` when `ptr` is `nil`.
    ///
    /// Used by `ComPtrs.initialize` for pointers written into out-parameters.
    fileprivate init?(takingOwnership ptr: UnsafeMutablePointer<CInterface>?) {
        guard let ptr else { return nil }
        self.pUnk = ptr
    }

    /// Releases ownership of the underlying pointer and returns it.
    ///
    /// This is useful when assigning to an out parameter and avoids an extra
    /// AddRef/Release pair. After this call the stored pointer is `nil`: `deinit`
    /// no longer calls `Release`, and `get()` / `asIUnknown(_:)` will trap.
    ///
    /// - Returns: The previously stored pointer as a raw pointer, or `nil` if
    ///   nothing was stored.
    public func detach() -> UnsafeMutableRawPointer? {
        let result = pUnk
        pUnk = nil
        return UnsafeMutableRawPointer(result)
    }

    /// Returns the stored pointer without changing its reference count.
    ///
    /// - Precondition: The pointer has not been detached.
    public func get() -> UnsafeMutablePointer<CInterface> {
        guard let pUnk else { preconditionFailure("get() called on nil pointer") }
        return pUnk
    }

    deinit {
        release()
    }

    /// Calls `Release` on the stored pointer, if there is one.
    private func release() {
        guard pUnk != nil else { return }
        asIUnknown {
            _ = $0.pointee.lpVtbl.pointee.Release($0)
        }
    }

    /// Invokes `body` with the stored pointer rebound to `C_IUnknown`.
    ///
    /// The rebinding is only in effect for the duration of `body`.
    ///
    /// - Parameter body: Closure receiving the pointer viewed as `C_IUnknown`.
    /// - Returns: Whatever `body` returns.
    /// - Throws: Rethrows any error thrown by `body`.
    /// - Precondition: The pointer has not been detached.
    func asIUnknown<ResultType>(_ body: (UnsafeMutablePointer<C_IUnknown>) throws -> ResultType) rethrows -> ResultType {
        guard let pUnk else { preconditionFailure("asIUnknown called on nil pointer") }
        return try pUnk.withMemoryRebound(to: C_IUnknown.self, capacity: 1) { try body($0) }
    }
}

public extension ComPtr {
    /// Calls `QueryInterface` on the wrapped object using `Interface.IID` and wraps
    /// the returned pointer in `Interface`.
    ///
    /// NOTE(review): the generic constraint was reconstructed from how `Interface`
    /// is used here (`Interface.IID` and an initializer accepting a `ComPtr`);
    /// confirm `IUnknown` is the intended bound.
    ///
    /// - Throws: Whatever `CHECKED` throws for the `QueryInterface` result
    ///   (`CHECKED` is defined elsewhere; presumably it throws on a failure
    ///   HRESULT - confirm).
    func queryInterface<Interface: IUnknown>() throws -> Interface {
        let queried = try self.asIUnknown { unknown in
            var iid = Interface.IID
            return try ComPtrs.initialize(to: C_IUnknown.self) { result in
                try CHECKED(unknown.pointee.lpVtbl.pointee.QueryInterface(unknown, &iid, &result))
            }
        }
        // The call above did not throw, so a nil out-pointer is treated as a
        // programmer error; trap with a message rather than an anonymous `!`.
        guard let queried else {
            preconditionFailure("QueryInterface succeeded but returned a nil pointer")
        }
        return .init(queried)
    }
}

/// `ComPtrs` properly initializes pointers that have ownership of the underlying raw
/// pointers. This is used at the ABI boundary layer, for example:
///
///     let (return1, return2) = try ComPtrs.initialize { return1Abi, return2Abi in
///         try CHECKED(pThis.pointee.lpVtbl.pointee.Method(pThis, &return1Abi, &return2Abi))
///     }
///
/// Every overload hands `body` one `nil`-initialized pointer per result and then wraps
/// each pointer with `ComPtr(takingOwnership:)`, i.e. without calling `AddRef`. A
/// pointer that is still `nil` after `body` returns yields a `nil` `ComPtr`.
///
/// If `body` throws, no `ComPtr` is created, so anything `body` had already written to
/// an out-parameter is not released by this type.
public struct ComPtrs {
    // Note: The single initialization methods still return a tuple for ease of code generation

    /// Single result written through an untyped (raw) out-pointer, which is then
    /// assumed to be bound to `I`.
    public static func initialize<I>(to: I.Type, _ body: (inout UnsafeMutableRawPointer?) throws -> ()) rethrows -> (ComPtr<I>?) {
        var ptrRaw: UnsafeMutableRawPointer?
        try body(&ptrRaw)
        return (ComPtr(takingOwnership: ptrRaw?.assumingMemoryBound(to: I.self)))
    }

    /// Single typed result.
    public static func initialize<I>(_ body: (inout UnsafeMutablePointer<I>?) throws -> ()) rethrows -> (ComPtr<I>?) {
        var ptr: UnsafeMutablePointer<I>?
        try body(&ptr)
        return (ComPtr(takingOwnership: ptr))
    }

    /// Two typed results.
    public static func initialize<I1, I2>(_ body: (inout UnsafeMutablePointer<I1>?, inout UnsafeMutablePointer<I2>?) throws -> ()) rethrows -> (ComPtr<I1>?, ComPtr<I2>?) {
        var ptr1: UnsafeMutablePointer<I1>?
        var ptr2: UnsafeMutablePointer<I2>?
        try body(&ptr1, &ptr2)
        return (ComPtr(takingOwnership: ptr1), ComPtr(takingOwnership: ptr2))
    }

    /// Three typed results.
    public static func initialize<I1, I2, I3>(_ body: (inout UnsafeMutablePointer<I1>?, inout UnsafeMutablePointer<I2>?, inout UnsafeMutablePointer<I3>?) throws -> ()) rethrows -> (ComPtr<I1>?, ComPtr<I2>?, ComPtr<I3>?) {
        var ptr1: UnsafeMutablePointer<I1>?
        var ptr2: UnsafeMutablePointer<I2>?
        var ptr3: UnsafeMutablePointer<I3>?
        try body(&ptr1, &ptr2, &ptr3)
        return (ComPtr(takingOwnership: ptr1), ComPtr(takingOwnership: ptr2), ComPtr(takingOwnership: ptr3))
    }

    /// Four typed results.
    public static func initialize<I1, I2, I3, I4>(_ body: (inout UnsafeMutablePointer<I1>?, inout UnsafeMutablePointer<I2>?, inout UnsafeMutablePointer<I3>?, inout UnsafeMutablePointer<I4>?) throws -> ()) rethrows -> (ComPtr<I1>?, ComPtr<I2>?, ComPtr<I3>?, ComPtr<I4>?) {
        var ptr1: UnsafeMutablePointer<I1>?
        var ptr2: UnsafeMutablePointer<I2>?
        var ptr3: UnsafeMutablePointer<I3>?
        var ptr4: UnsafeMutablePointer<I4>?
        try body(&ptr1, &ptr2, &ptr3, &ptr4)
        return (ComPtr(takingOwnership: ptr1), ComPtr(takingOwnership: ptr2), ComPtr(takingOwnership: ptr3), ComPtr(takingOwnership: ptr4))
    }

    /// Five typed results.
    public static func initialize<I1, I2, I3, I4, I5>(_ body: (inout UnsafeMutablePointer<I1>?, inout UnsafeMutablePointer<I2>?, inout UnsafeMutablePointer<I3>?, inout UnsafeMutablePointer<I4>?, inout UnsafeMutablePointer<I5>?) throws -> ()) rethrows -> (ComPtr<I1>?, ComPtr<I2>?, ComPtr<I3>?, ComPtr<I4>?, ComPtr<I5>?) {
        var ptr1: UnsafeMutablePointer<I1>?
        var ptr2: UnsafeMutablePointer<I2>?
        var ptr3: UnsafeMutablePointer<I3>?
        var ptr4: UnsafeMutablePointer<I4>?
        var ptr5: UnsafeMutablePointer<I5>?
        try body(&ptr1, &ptr2, &ptr3, &ptr4, &ptr5)
        return (ComPtr(takingOwnership: ptr1), ComPtr(takingOwnership: ptr2), ComPtr(takingOwnership: ptr3), ComPtr(takingOwnership: ptr4), ComPtr(takingOwnership: ptr5))
    }
}