message_map_container.cc 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569
  1. // Protocol Buffers - Google's data interchange format
  2. // Copyright 2008 Google Inc. All rights reserved.
  3. // https://developers.google.com/protocol-buffers/
  4. //
  5. // Redistribution and use in source and binary forms, with or without
  6. // modification, are permitted provided that the following conditions are
  7. // met:
  8. //
  9. // * Redistributions of source code must retain the above copyright
  10. // notice, this list of conditions and the following disclaimer.
  11. // * Redistributions in binary form must reproduce the above
  12. // copyright notice, this list of conditions and the following disclaimer
  13. // in the documentation and/or other materials provided with the
  14. // distribution.
  15. // * Neither the name of Google Inc. nor the names of its
  16. // contributors may be used to endorse or promote products derived from
  17. // this software without specific prior written permission.
  18. //
  19. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  20. // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  21. // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  22. // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  23. // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  24. // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  25. // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  26. // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  27. // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  28. // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  29. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  30. // Author: haberman@google.com (Josh Haberman)
  31. #include <google/protobuf/pyext/message_map_container.h>
  32. #include <google/protobuf/stubs/logging.h>
  33. #include <google/protobuf/stubs/common.h>
  34. #include <google/protobuf/message.h>
  35. #include <google/protobuf/pyext/message.h>
  36. #include <google/protobuf/pyext/scoped_pyobject_ptr.h>
  37. namespace google {
  38. namespace protobuf {
  39. namespace python {
  40. struct MessageMapIterator {
  41. PyObject_HEAD;
  42. // This dict contains the full contents of what we want to iterate over.
  43. // There's no way to avoid building this, because the list representation
  44. // (which is canonical) can contain duplicate keys. So at the very least we
  45. // need a set that lets us skip duplicate keys. And at the point that we're
  46. // doing that, we might as well just build the actual dict we're iterating
  47. // over and use dict's built-in iterator.
  48. PyObject* dict;
  49. // An iterator on dict.
  50. PyObject* iter;
  51. // A pointer back to the container, so we can notice changes to the version.
  52. MessageMapContainer* container;
  53. // The version of the map when we took the iterator to it.
  54. //
  55. // We store this so that if the map is modified during iteration we can throw
  56. // an error.
  57. uint64 version;
  58. };
  59. static MessageMapIterator* GetIter(PyObject* obj) {
  60. return reinterpret_cast<MessageMapIterator*>(obj);
  61. }
  62. namespace message_map_container {
  63. static MessageMapContainer* GetMap(PyObject* obj) {
  64. return reinterpret_cast<MessageMapContainer*>(obj);
  65. }
  66. // The private constructor of MessageMapContainer objects.
  67. PyObject* NewContainer(CMessage* parent,
  68. const google::protobuf::FieldDescriptor* parent_field_descriptor,
  69. PyObject* concrete_class) {
  70. if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
  71. return NULL;
  72. }
  73. #if PY_MAJOR_VERSION >= 3
  74. PyObject* obj = PyType_GenericAlloc(
  75. reinterpret_cast<PyTypeObject *>(MessageMapContainer_Type), 0);
  76. #else
  77. PyObject* obj = PyType_GenericAlloc(&MessageMapContainer_Type, 0);
  78. #endif
  79. if (obj == NULL) {
  80. return PyErr_Format(PyExc_RuntimeError,
  81. "Could not allocate new container.");
  82. }
  83. MessageMapContainer* self = GetMap(obj);
  84. self->message = parent->message;
  85. self->parent = parent;
  86. self->parent_field_descriptor = parent_field_descriptor;
  87. self->owner = parent->owner;
  88. self->version = 0;
  89. self->key_field_descriptor =
  90. parent_field_descriptor->message_type()->FindFieldByName("key");
  91. self->value_field_descriptor =
  92. parent_field_descriptor->message_type()->FindFieldByName("value");
  93. self->message_dict = PyDict_New();
  94. if (self->message_dict == NULL) {
  95. return PyErr_Format(PyExc_RuntimeError,
  96. "Could not allocate message dict.");
  97. }
  98. Py_INCREF(concrete_class);
  99. self->subclass_init = concrete_class;
  100. if (self->key_field_descriptor == NULL ||
  101. self->value_field_descriptor == NULL) {
  102. Py_DECREF(obj);
  103. return PyErr_Format(PyExc_KeyError,
  104. "Map entry descriptor did not have key/value fields");
  105. }
  106. return obj;
  107. }
  108. // Initializes the underlying Message object of "to" so it becomes a new parent
  109. // repeated scalar, and copies all the values from "from" to it. A child scalar
  110. // container can be released by passing it as both from and to (e.g. making it
  111. // the recipient of the new parent message and copying the values from itself).
  112. static int InitializeAndCopyToParentContainer(
  113. MessageMapContainer* from,
  114. MessageMapContainer* to) {
  115. // For now we require from == to, re-evaluate if we want to support deep copy
  116. // as in repeated_composite_container.cc.
  117. GOOGLE_DCHECK(from == to);
  118. Message* old_message = from->message;
  119. Message* new_message = old_message->New();
  120. to->parent = NULL;
  121. to->parent_field_descriptor = from->parent_field_descriptor;
  122. to->message = new_message;
  123. to->owner.reset(new_message);
  124. vector<const FieldDescriptor*> fields;
  125. fields.push_back(from->parent_field_descriptor);
  126. old_message->GetReflection()->SwapFields(old_message, new_message, fields);
  127. return 0;
  128. }
  129. static PyObject* GetCMessage(MessageMapContainer* self, Message* entry) {
  130. // Get or create the CMessage object corresponding to this message.
  131. Message* message = entry->GetReflection()->MutableMessage(
  132. entry, self->value_field_descriptor);
  133. ScopedPyObjectPtr key(PyLong_FromVoidPtr(message));
  134. PyObject* ret = PyDict_GetItem(self->message_dict, key);
  135. if (ret == NULL) {
  136. CMessage* cmsg = cmessage::NewEmptyMessage(self->subclass_init,
  137. message->GetDescriptor());
  138. ret = reinterpret_cast<PyObject*>(cmsg);
  139. if (cmsg == NULL) {
  140. return NULL;
  141. }
  142. cmsg->owner = self->owner;
  143. cmsg->message = message;
  144. cmsg->parent = self->parent;
  145. if (PyDict_SetItem(self->message_dict, key, ret) < 0) {
  146. Py_DECREF(ret);
  147. return NULL;
  148. }
  149. } else {
  150. Py_INCREF(ret);
  151. }
  152. return ret;
  153. }
  154. int Release(MessageMapContainer* self) {
  155. InitializeAndCopyToParentContainer(self, self);
  156. return 0;
  157. }
  158. void SetOwner(MessageMapContainer* self,
  159. const shared_ptr<Message>& new_owner) {
  160. self->owner = new_owner;
  161. }
  162. Py_ssize_t Length(PyObject* _self) {
  163. MessageMapContainer* self = GetMap(_self);
  164. google::protobuf::Message* message = self->message;
  165. return message->GetReflection()->FieldSize(*message,
  166. self->parent_field_descriptor);
  167. }
  168. int MapKeyMatches(MessageMapContainer* self, const Message* entry,
  169. PyObject* key) {
  170. // TODO(haberman): do we need more strict type checking?
  171. ScopedPyObjectPtr entry_key(
  172. cmessage::InternalGetScalar(entry, self->key_field_descriptor));
  173. int ret = PyObject_RichCompareBool(key, entry_key, Py_EQ);
  174. return ret;
  175. }
  176. int SetItem(PyObject *_self, PyObject *key, PyObject *v) {
  177. if (v) {
  178. PyErr_Format(PyExc_ValueError,
  179. "Direct assignment of submessage not allowed");
  180. return -1;
  181. }
  182. // Now we know that this is a delete, not a set.
  183. MessageMapContainer* self = GetMap(_self);
  184. cmessage::AssureWritable(self->parent);
  185. Message* message = self->message;
  186. const Reflection* reflection = message->GetReflection();
  187. size_t size =
  188. reflection->FieldSize(*message, self->parent_field_descriptor);
  189. // Right now the Reflection API doesn't support map lookup, so we implement it
  190. // via linear search. We need to search from the end because the underlying
  191. // representation can have duplicates if a user calls MergeFrom(); the last
  192. // one needs to win.
  193. //
  194. // TODO(haberman): add lookup API to Reflection API.
  195. bool found = false;
  196. for (int i = size - 1; i >= 0; i--) {
  197. Message* entry = reflection->MutableRepeatedMessage(
  198. message, self->parent_field_descriptor, i);
  199. int matches = MapKeyMatches(self, entry, key);
  200. if (matches < 0) return -1;
  201. if (matches) {
  202. found = true;
  203. if (i != size - 1) {
  204. reflection->SwapElements(message, self->parent_field_descriptor, i,
  205. size - 1);
  206. }
  207. reflection->RemoveLast(message, self->parent_field_descriptor);
  208. // Can't exit now, the repeated field representation of maps allows
  209. // duplicate keys, and we have to be sure to remove all of them.
  210. }
  211. }
  212. if (!found) {
  213. PyErr_Format(PyExc_KeyError, "Key not present in map");
  214. return -1;
  215. }
  216. self->version++;
  217. return 0;
  218. }
  219. PyObject* GetIterator(PyObject *_self) {
  220. MessageMapContainer* self = GetMap(_self);
  221. ScopedPyObjectPtr obj(PyType_GenericAlloc(&MessageMapIterator_Type, 0));
  222. if (obj == NULL) {
  223. return PyErr_Format(PyExc_KeyError, "Could not allocate iterator");
  224. }
  225. MessageMapIterator* iter = GetIter(obj);
  226. Py_INCREF(self);
  227. iter->container = self;
  228. iter->version = self->version;
  229. iter->dict = PyDict_New();
  230. if (iter->dict == NULL) {
  231. return PyErr_Format(PyExc_RuntimeError,
  232. "Could not allocate dict for iterator.");
  233. }
  234. // Build the entire map into a dict right now. Start from the beginning so
  235. // that later entries win in the case of duplicates.
  236. Message* message = self->message;
  237. const Reflection* reflection = message->GetReflection();
  238. // Right now the Reflection API doesn't support map lookup, so we implement it
  239. // via linear search. We need to search from the end because the underlying
  240. // representation can have duplicates if a user calls MergeFrom(); the last
  241. // one needs to win.
  242. //
  243. // TODO(haberman): add lookup API to Reflection API.
  244. size_t size =
  245. reflection->FieldSize(*message, self->parent_field_descriptor);
  246. for (int i = size - 1; i >= 0; i--) {
  247. Message* entry = reflection->MutableRepeatedMessage(
  248. message, self->parent_field_descriptor, i);
  249. ScopedPyObjectPtr key(
  250. cmessage::InternalGetScalar(entry, self->key_field_descriptor));
  251. if (PyDict_SetItem(iter->dict, key.get(), GetCMessage(self, entry)) < 0) {
  252. return PyErr_Format(PyExc_RuntimeError,
  253. "SetItem failed in iterator construction.");
  254. }
  255. }
  256. iter->iter = PyObject_GetIter(iter->dict);
  257. return obj.release();
  258. }
  259. PyObject* GetItem(PyObject* _self, PyObject* key) {
  260. MessageMapContainer* self = GetMap(_self);
  261. cmessage::AssureWritable(self->parent);
  262. Message* message = self->message;
  263. const Reflection* reflection = message->GetReflection();
  264. // Right now the Reflection API doesn't support map lookup, so we implement it
  265. // via linear search. We need to search from the end because the underlying
  266. // representation can have duplicates if a user calls MergeFrom(); the last
  267. // one needs to win.
  268. //
  269. // TODO(haberman): add lookup API to Reflection API.
  270. size_t size =
  271. reflection->FieldSize(*message, self->parent_field_descriptor);
  272. for (int i = size - 1; i >= 0; i--) {
  273. Message* entry = reflection->MutableRepeatedMessage(
  274. message, self->parent_field_descriptor, i);
  275. int matches = MapKeyMatches(self, entry, key);
  276. if (matches < 0) return NULL;
  277. if (matches) {
  278. return GetCMessage(self, entry);
  279. }
  280. }
  281. // Key is not already present; insert a new entry.
  282. Message* entry =
  283. reflection->AddMessage(message, self->parent_field_descriptor);
  284. self->version++;
  285. if (cmessage::InternalSetNonOneofScalar(entry, self->key_field_descriptor,
  286. key) < 0) {
  287. reflection->RemoveLast(message, self->parent_field_descriptor);
  288. return NULL;
  289. }
  290. return GetCMessage(self, entry);
  291. }
  292. PyObject* Contains(PyObject* _self, PyObject* key) {
  293. MessageMapContainer* self = GetMap(_self);
  294. Message* message = self->message;
  295. const Reflection* reflection = message->GetReflection();
  296. // Right now the Reflection API doesn't support map lookup, so we implement it
  297. // via linear search.
  298. //
  299. // TODO(haberman): add lookup API to Reflection API.
  300. size_t size =
  301. reflection->FieldSize(*message, self->parent_field_descriptor);
  302. for (int i = 0; i < size; i++) {
  303. Message* entry = reflection->MutableRepeatedMessage(
  304. message, self->parent_field_descriptor, i);
  305. int matches = MapKeyMatches(self, entry, key);
  306. if (matches < 0) return NULL;
  307. if (matches) {
  308. Py_RETURN_TRUE;
  309. }
  310. }
  311. Py_RETURN_FALSE;
  312. }
  313. PyObject* Clear(PyObject* _self) {
  314. MessageMapContainer* self = GetMap(_self);
  315. cmessage::AssureWritable(self->parent);
  316. Message* message = self->message;
  317. const Reflection* reflection = message->GetReflection();
  318. self->version++;
  319. reflection->ClearField(message, self->parent_field_descriptor);
  320. Py_RETURN_NONE;
  321. }
  322. PyObject* Get(PyObject* self, PyObject* args) {
  323. PyObject* key;
  324. PyObject* default_value = NULL;
  325. if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) {
  326. return NULL;
  327. }
  328. ScopedPyObjectPtr is_present(Contains(self, key));
  329. if (is_present.get() == NULL) {
  330. return NULL;
  331. }
  332. if (PyObject_IsTrue(is_present.get())) {
  333. return GetItem(self, key);
  334. } else {
  335. if (default_value != NULL) {
  336. Py_INCREF(default_value);
  337. return default_value;
  338. } else {
  339. Py_RETURN_NONE;
  340. }
  341. }
  342. }
  343. static PyMappingMethods MpMethods = {
  344. Length, // mp_length
  345. GetItem, // mp_subscript
  346. SetItem, // mp_ass_subscript
  347. };
  348. static void Dealloc(PyObject* _self) {
  349. MessageMapContainer* self = GetMap(_self);
  350. self->owner.reset();
  351. Py_DECREF(self->message_dict);
  352. Py_TYPE(_self)->tp_free(_self);
  353. }
  354. static PyMethodDef Methods[] = {
  355. { "__contains__", (PyCFunction)Contains, METH_O,
  356. "Tests whether the map contains this element."},
  357. { "clear", (PyCFunction)Clear, METH_NOARGS,
  358. "Removes all elements from the map."},
  359. { "get", Get, METH_VARARGS,
  360. "Gets the value for the given key if present, or otherwise a default" },
  361. { "get_or_create", GetItem, METH_O,
  362. "Alias for getitem, useful to make explicit that the map is mutated." },
  363. /*
  364. { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
  365. "Makes a deep copy of the class." },
  366. { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
  367. "Outputs picklable representation of the repeated field." },
  368. */
  369. {NULL, NULL},
  370. };
  371. } // namespace message_map_container
  372. namespace message_map_iterator {
  373. static void Dealloc(PyObject* _self) {
  374. MessageMapIterator* self = GetIter(_self);
  375. Py_DECREF(self->dict);
  376. Py_DECREF(self->iter);
  377. Py_DECREF(self->container);
  378. Py_TYPE(_self)->tp_free(_self);
  379. }
  380. PyObject* IterNext(PyObject* _self) {
  381. MessageMapIterator* self = GetIter(_self);
  382. // This won't catch mutations to the map performed by MergeFrom(); no easy way
  383. // to address that.
  384. if (self->version != self->container->version) {
  385. return PyErr_Format(PyExc_RuntimeError,
  386. "Map modified during iteration.");
  387. }
  388. return PyIter_Next(self->iter);
  389. }
  390. } // namespace message_map_iterator
  391. #if PY_MAJOR_VERSION >= 3
  392. static PyType_Slot MessageMapContainer_Type_slots[] = {
  393. {Py_tp_dealloc, (void *)message_map_container::Dealloc},
  394. {Py_mp_length, (void *)message_map_container::Length},
  395. {Py_mp_subscript, (void *)message_map_container::GetItem},
  396. {Py_mp_ass_subscript, (void *)message_map_container::SetItem},
  397. {Py_tp_methods, (void *)message_map_container::Methods},
  398. {Py_tp_iter, (void *)message_map_container::GetIterator},
  399. {0, 0}
  400. };
  401. PyType_Spec MessageMapContainer_Type_spec = {
  402. FULL_MODULE_NAME ".MessageMapContainer",
  403. sizeof(MessageMapContainer),
  404. 0,
  405. Py_TPFLAGS_DEFAULT,
  406. MessageMapContainer_Type_slots
  407. };
  408. PyObject *MessageMapContainer_Type;
  409. #else
  410. PyTypeObject MessageMapContainer_Type = {
  411. PyVarObject_HEAD_INIT(&PyType_Type, 0)
  412. FULL_MODULE_NAME ".MessageMapContainer", // tp_name
  413. sizeof(MessageMapContainer), // tp_basicsize
  414. 0, // tp_itemsize
  415. message_map_container::Dealloc, // tp_dealloc
  416. 0, // tp_print
  417. 0, // tp_getattr
  418. 0, // tp_setattr
  419. 0, // tp_compare
  420. 0, // tp_repr
  421. 0, // tp_as_number
  422. 0, // tp_as_sequence
  423. &message_map_container::MpMethods, // tp_as_mapping
  424. 0, // tp_hash
  425. 0, // tp_call
  426. 0, // tp_str
  427. 0, // tp_getattro
  428. 0, // tp_setattro
  429. 0, // tp_as_buffer
  430. Py_TPFLAGS_DEFAULT, // tp_flags
  431. "A map container for message", // tp_doc
  432. 0, // tp_traverse
  433. 0, // tp_clear
  434. 0, // tp_richcompare
  435. 0, // tp_weaklistoffset
  436. message_map_container::GetIterator, // tp_iter
  437. 0, // tp_iternext
  438. message_map_container::Methods, // tp_methods
  439. 0, // tp_members
  440. 0, // tp_getset
  441. 0, // tp_base
  442. 0, // tp_dict
  443. 0, // tp_descr_get
  444. 0, // tp_descr_set
  445. 0, // tp_dictoffset
  446. 0, // tp_init
  447. };
  448. #endif
  449. PyTypeObject MessageMapIterator_Type = {
  450. PyVarObject_HEAD_INIT(&PyType_Type, 0)
  451. FULL_MODULE_NAME ".MessageMapIterator", // tp_name
  452. sizeof(MessageMapIterator), // tp_basicsize
  453. 0, // tp_itemsize
  454. message_map_iterator::Dealloc, // tp_dealloc
  455. 0, // tp_print
  456. 0, // tp_getattr
  457. 0, // tp_setattr
  458. 0, // tp_compare
  459. 0, // tp_repr
  460. 0, // tp_as_number
  461. 0, // tp_as_sequence
  462. 0, // tp_as_mapping
  463. 0, // tp_hash
  464. 0, // tp_call
  465. 0, // tp_str
  466. 0, // tp_getattro
  467. 0, // tp_setattro
  468. 0, // tp_as_buffer
  469. Py_TPFLAGS_DEFAULT, // tp_flags
  470. "A scalar map iterator", // tp_doc
  471. 0, // tp_traverse
  472. 0, // tp_clear
  473. 0, // tp_richcompare
  474. 0, // tp_weaklistoffset
  475. PyObject_SelfIter, // tp_iter
  476. message_map_iterator::IterNext, // tp_iternext
  477. 0, // tp_methods
  478. 0, // tp_members
  479. 0, // tp_getset
  480. 0, // tp_base
  481. 0, // tp_dict
  482. 0, // tp_descr_get
  483. 0, // tp_descr_set
  484. 0, // tp_dictoffset
  485. 0, // tp_init
  486. };
  487. } // namespace python
  488. } // namespace protobuf
  489. } // namespace google