master
  1/**
  2 * This file has no copyright assigned and is placed in the Public Domain.
  3 * This file is part of the mingw-w64 runtime package.
  4 * No warranty is given; refer to the file DISCLAIMER.PD within this package.
  5 */
  6
  7#ifndef _WRL_CLIENT_H_
  8#define _WRL_CLIENT_H_
  9
 10#include <cstddef>
 11#include <unknwn.h>
 12/* #include <weakreference.h> */
 13#include <roapi.h>
 14
 15/* #include <wrl/def.h> */
 16#include <wrl/internal.h>
 17
 18namespace Microsoft {
 19    namespace WRL {
 20        namespace Details {
 21            template <typename T> class ComPtrRefBase {
 22            protected:
 23                T* ptr_;
 24
 25            public:
 26                typedef typename T::InterfaceType InterfaceType;
 27
 28#ifndef __WRL_CLASSIC_COM__
 29                operator IInspectable**() const throw()  {
 30                    static_assert(__is_base_of(IInspectable, InterfaceType), "Invalid cast");
 31                    return reinterpret_cast<IInspectable**>(ptr_->ReleaseAndGetAddressOf());
 32                }
 33#endif
 34
 35                operator IUnknown**() const throw() {
 36                    static_assert(__is_base_of(IUnknown, InterfaceType), "Invalid cast");
 37                    return reinterpret_cast<IUnknown**>(ptr_->ReleaseAndGetAddressOf());
 38                }
 39            };
 40
 41            template <typename T> class ComPtrRef : public Details::ComPtrRefBase<T> {
 42            public:
 43                ComPtrRef(T *ptr) throw() {
 44                    ComPtrRefBase<T>::ptr_ = ptr;
 45                }
 46
 47                operator void**() const throw() {
 48                    return reinterpret_cast<void**>(ComPtrRefBase<T>::ptr_->ReleaseAndGetAddressOf());
 49                }
 50
 51                operator T*() throw() {
 52                    *ComPtrRefBase<T>::ptr_ = nullptr;
 53                    return ComPtrRefBase<T>::ptr_;
 54                }
 55
 56                operator typename ComPtrRefBase<T>::InterfaceType**() throw() {
 57                    return ComPtrRefBase<T>::ptr_->ReleaseAndGetAddressOf();
 58                }
 59
 60                typename ComPtrRefBase<T>::InterfaceType *operator*() throw() {
 61                    return ComPtrRefBase<T>::ptr_->Get();
 62                }
 63
 64                typename ComPtrRefBase<T>::InterfaceType *const *GetAddressOf() const throw() {
 65                    return ComPtrRefBase<T>::ptr_->GetAddressOf();
 66                }
 67
 68                typename ComPtrRefBase<T>::InterfaceType **ReleaseAndGetAddressOf() throw() {
 69                    return ComPtrRefBase<T>::ptr_->ReleaseAndGetAddressOf();
 70                }
 71            };
 72
 73        }
 74
 75        template<typename T> class ComPtr {
 76        public:
 77            typedef T InterfaceType;
 78
 79            ComPtr() throw() : ptr_(nullptr) {}
 80            ComPtr(decltype(nullptr)) throw() : ptr_(nullptr) {}
 81
 82            template<class U> ComPtr(U *other) throw() : ptr_(other) {
 83                InternalAddRef();
 84            }
 85
 86            ComPtr(const ComPtr &other) throw() : ptr_(other.ptr_) {
 87                InternalAddRef();
 88            }
 89
 90            template<class U>
 91            ComPtr(const ComPtr<U> &other) throw() : ptr_(other.Get()) {
 92                InternalAddRef();
 93            }
 94
 95            ComPtr(ComPtr &&other) throw() : ptr_(nullptr) {
 96                if(this != reinterpret_cast<ComPtr*>(&reinterpret_cast<unsigned char&>(other)))
 97                    Swap(other);
 98            }
 99
100            template<class U>
101            ComPtr(ComPtr<U>&& other) throw() : ptr_(other.Detach()) {}
102
103            ~ComPtr() throw() {
104                InternalRelease();
105            }
106
107            ComPtr &operator=(decltype(nullptr)) throw() {
108                InternalRelease();
109                return *this;
110            }
111
112            ComPtr &operator=(InterfaceType *other) throw() {
113                if (ptr_ != other) {
114                    InternalRelease();
115                    ptr_ = other;
116                    InternalAddRef();
117                }
118                return *this;
119            }
120
121            template<typename U>
122            ComPtr &operator=(U *other) throw()  {
123                if (ptr_ != other) {
124                    InternalRelease();
125                    ptr_ = other;
126                    InternalAddRef();
127                }
128                return *this;
129            }
130
131            ComPtr& operator=(const ComPtr &other) throw() {
132                if (ptr_ != other.ptr_)
133                    ComPtr(other).Swap(*this);
134                return *this;
135            }
136
137            template<class U>
138            ComPtr &operator=(const ComPtr<U> &other) throw() {
139                ComPtr(other).Swap(*this);
140                return *this;
141            }
142
143            ComPtr& operator=(ComPtr &&other) throw() {
144                ComPtr(other).Swap(*this);
145                return *this;
146            }
147
148            template<class U>
149            ComPtr& operator=(ComPtr<U> &&other) throw() {
150                ComPtr(other).Swap(*this);
151                return *this;
152            }
153
154            void Swap(ComPtr &&r) throw() {
155                InterfaceType *tmp = ptr_;
156                ptr_ = r.ptr_;
157                r.ptr_ = tmp;
158            }
159
160            void Swap(ComPtr &r) throw() {
161                InterfaceType *tmp = ptr_;
162                ptr_ = r.ptr_;
163                r.ptr_ = tmp;
164            }
165
166            operator Details::BoolType() const throw() {
167                return Get() != nullptr ? &Details::BoolStruct::Member : nullptr;
168            }
169
170            InterfaceType *Get() const throw()  {
171                return ptr_;
172            }
173
174            InterfaceType *operator->() const throw() {
175                return ptr_;
176            }
177
178            Details::ComPtrRef<ComPtr<T>> operator&() throw()  {
179                return Details::ComPtrRef<ComPtr<T>>(this);
180            }
181
182            const Details::ComPtrRef<const ComPtr<T>> operator&() const throw() {
183                return Details::ComPtrRef<const ComPtr<T>>(this);
184            }
185
186            InterfaceType *const *GetAddressOf() const throw() {
187                return &ptr_;
188            }
189
190            InterfaceType **GetAddressOf() throw() {
191                return &ptr_;
192            }
193
194            InterfaceType **ReleaseAndGetAddressOf() throw() {
195                InternalRelease();
196                return &ptr_;
197            }
198
199            InterfaceType *Detach() throw() {
200                T* ptr = ptr_;
201                ptr_ = nullptr;
202                return ptr;
203            }
204
205            void Attach(InterfaceType *other) throw() {
206                if (ptr_ != other) {
207                    InternalRelease();
208                    ptr_ = other;
209                }
210            }
211
212            unsigned long Reset() {
213                return InternalRelease();
214            }
215
216            HRESULT CopyTo(InterfaceType **ptr) const throw() {
217                InternalAddRef();
218                *ptr = ptr_;
219                return S_OK;
220            }
221
222            HRESULT CopyTo(REFIID riid, void **ptr) const throw() {
223                return ptr_->QueryInterface(riid, ptr);
224            }
225
226            template<typename U>
227            HRESULT CopyTo(U **ptr) const throw() {
228                return ptr_->QueryInterface(__uuidof(U), reinterpret_cast<void**>(ptr));
229            }
230
231            template<typename U>
232            HRESULT As(Details::ComPtrRef<ComPtr<U>> p) const throw() {
233                return ptr_->QueryInterface(__uuidof(U), p);
234            }
235
236            template<typename U>
237            HRESULT As(ComPtr<U> *p) const throw() {
238                return ptr_->QueryInterface(__uuidof(U), reinterpret_cast<void**>(p->ReleaseAndGetAddressOf()));
239            }
240
241            HRESULT AsIID(REFIID riid, ComPtr<IUnknown> *p) const throw() {
242                return ptr_->QueryInterface(riid, reinterpret_cast<void**>(p->ReleaseAndGetAddressOf()));
243            }
244
245            /*
246            HRESULT AsWeak(WeakRef *pWeakRef) const throw() {
247                return ::Microsoft::WRL::AsWeak(ptr_, pWeakRef);
248            }
249            */
250        protected:
251            InterfaceType *ptr_;
252
253            void InternalAddRef() const throw() {
254                if(ptr_)
255                    ptr_->AddRef();
256            }
257
258            unsigned long InternalRelease() throw() {
259                InterfaceType *tmp = ptr_;
260                if(!tmp)
261                    return 0;
262                ptr_ = nullptr;
263                return tmp->Release();
264            }
265        };
266
267        template <class T, class U>
268        bool operator==(const ComPtr<T> &a, const ComPtr<U> &b) throw()
269        {
270            static_assert(__is_base_of(T, U) || __is_base_of(U, T), "Type incompatible");
271            return a.Get() == b.Get();
272        }
273
274        template <class T>
275        bool operator==(const ComPtr<T> &a, std::nullptr_t) throw()
276        {
277            return a.Get() == nullptr;
278        }
279
280        template <class T>
281        bool operator==(std::nullptr_t, const ComPtr<T> &a) throw()
282        {
283            return a.Get() == nullptr;
284        }
285
286        template <class T, class U>
287        bool operator!=(const ComPtr<T> &a, const ComPtr<U> &b) throw()
288        {
289            static_assert(__is_base_of(T, U) || __is_base_of(U, T), "Type incompatible");
290            return a.Get() != b.Get();
291        }
292
293        template <class T>
294        bool operator!=(const ComPtr<T> &a, std::nullptr_t) throw()
295        {
296            return a.Get() != nullptr;
297        }
298
299        template <class T>
300        bool operator!=(std::nullptr_t, const ComPtr<T> &a) throw()
301        {
302            return a.Get() != nullptr;
303        }
304
305        template <class T, class U>
306        bool operator<(const ComPtr<T> &a, const ComPtr<U> &b) throw()
307        {
308            static_assert(__is_base_of(T, U) || __is_base_of(U, T), "Type incompatible");
309            return a.Get() < b.Get();
310        }
311    }
312}
313
314template<typename T>
315void **IID_PPV_ARGS_Helper(::Microsoft::WRL::Details::ComPtrRef<T> pp) throw() {
316    static_assert(__is_base_of(IUnknown, typename T::InterfaceType), "Expected COM interface");
317    return pp;
318}
319
320namespace Windows {
321    namespace Foundation {
322        template<typename T>
323        inline HRESULT ActivateInstance(HSTRING classid, ::Microsoft::WRL::Details::ComPtrRef<T> instance) throw() {
324            return ActivateInstance(classid, instance.ReleaseAndGetAddressOf());
325        }
326
327        template<typename T>
328        inline HRESULT GetActivationFactory(HSTRING classid, ::Microsoft::WRL::Details::ComPtrRef<T> factory) throw() {
329            return RoGetActivationFactory(classid, IID_INS_ARGS(factory.ReleaseAndGetAddressOf()));
330        }
331    }
332}
333
334namespace ABI {
335    namespace Windows {
336        namespace Foundation {
337            template<typename T>
338            inline HRESULT ActivateInstance(HSTRING classid, ::Microsoft::WRL::Details::ComPtrRef<T> instance) throw() {
339                return ActivateInstance(classid, instance.ReleaseAndGetAddressOf());
340            }
341
342            template<typename T>
343            inline HRESULT GetActivationFactory(HSTRING classid, ::Microsoft::WRL::Details::ComPtrRef<T> factory) throw() {
344                return RoGetActivationFactory(classid, IID_INS_ARGS(factory.ReleaseAndGetAddressOf()));
345            }
346        }
347    }
348}
349
350#endif