// test_sequences_and_iterators.cpp revision 14299:2fbea9df56d2
113540Sandrea.mondelli@ucf.edu/* 24130Ssaidi@eecs.umich.edu tests/test_sequences_and_iterators.cpp -- supporting Pythons' sequence protocol, iterators, 31897Sstever@eecs.umich.edu etc. 41897Sstever@eecs.umich.edu 51897Sstever@eecs.umich.edu Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch> 61897Sstever@eecs.umich.edu 71897Sstever@eecs.umich.edu All rights reserved. Use of this source code is governed by a 81897Sstever@eecs.umich.edu BSD-style license that can be found in the LICENSE file. 91897Sstever@eecs.umich.edu*/ 101897Sstever@eecs.umich.edu 111897Sstever@eecs.umich.edu#include "pybind11_tests.h" 121897Sstever@eecs.umich.edu#include "constructor_stats.h" 131897Sstever@eecs.umich.edu#include <pybind11/operators.h> 141897Sstever@eecs.umich.edu#include <pybind11/stl.h> 151897Sstever@eecs.umich.edu 161897Sstever@eecs.umich.edutemplate<typename T> 171897Sstever@eecs.umich.educlass NonZeroIterator { 181897Sstever@eecs.umich.edu const T* ptr_; 191897Sstever@eecs.umich.edupublic: 201897Sstever@eecs.umich.edu NonZeroIterator(const T* ptr) : ptr_(ptr) {} 211897Sstever@eecs.umich.edu const T& operator*() const { return *ptr_; } 221897Sstever@eecs.umich.edu NonZeroIterator& operator++() { ++ptr_; return *this; } 231897Sstever@eecs.umich.edu}; 241897Sstever@eecs.umich.edu 251897Sstever@eecs.umich.educlass NonZeroSentinel {}; 261897Sstever@eecs.umich.edu 271897Sstever@eecs.umich.edutemplate<typename A, typename B> 281897Sstever@eecs.umich.edubool operator==(const NonZeroIterator<std::pair<A, B>>& it, const NonZeroSentinel&) { 291897Sstever@eecs.umich.edu return !(*it).first || !(*it).second; 301897Sstever@eecs.umich.edu} 311897Sstever@eecs.umich.edu 321897Sstever@eecs.umich.edutemplate <typename PythonType> 331897Sstever@eecs.umich.edupy::list test_random_access_iterator(PythonType x) { 344961Ssaidi@eecs.umich.edu if (x.size() < 5) 351897Sstever@eecs.umich.edu throw py::value_error("Please provide at least 5 elements for testing."); 361897Sstever@eecs.umich.edu 
371897Sstever@eecs.umich.edu auto checks = py::list(); 381897Sstever@eecs.umich.edu auto assert_equal = [&checks](py::handle a, py::handle b) { 397047Snate@binkert.org auto result = PyObject_RichCompareBool(a.ptr(), b.ptr(), Py_EQ); 408319Ssteve.reinhardt@amd.com if (result == -1) { throw py::error_already_set(); } 417047Snate@binkert.org checks.append(result != 0); 428319Ssteve.reinhardt@amd.com }; 4311706Sandreas.hansson@arm.com 448811Sandreas.hansson@arm.com auto it = x.begin(); 459850Sandreas.hansson@arm.com assert_equal(x[0], *it); 4611706Sandreas.hansson@arm.com assert_equal(x[0], it[0]); 4711706Sandreas.hansson@arm.com assert_equal(x[1], it[1]); 4811706Sandreas.hansson@arm.com 4911706Sandreas.hansson@arm.com assert_equal(x[1], *(++it)); 508811Sandreas.hansson@arm.com assert_equal(x[1], *(it++)); 518811Sandreas.hansson@arm.com assert_equal(x[2], *it); 5210007Snilay@cs.wisc.edu assert_equal(x[3], *(it += 1)); 5311308Santhony.gutierrez@amd.com assert_equal(x[2], *(--it)); 5411730Sar4jc@virginia.edu assert_equal(x[2], *(it--)); 5511308Santhony.gutierrez@amd.com assert_equal(x[1], *it); 567047Snate@binkert.org assert_equal(x[0], *(it -= 1)); 578811Sandreas.hansson@arm.com 588811Sandreas.hansson@arm.com assert_equal(it->attr("real"), x[0].attr("real")); 598811Sandreas.hansson@arm.com assert_equal((it + 1)->attr("real"), x[1].attr("real")); 608319Ssteve.reinhardt@amd.com 618319Ssteve.reinhardt@amd.com assert_equal(x[1], *(it + 1)); 628319Ssteve.reinhardt@amd.com assert_equal(x[1], *(1 + it)); 638319Ssteve.reinhardt@amd.com it += 3; 648319Ssteve.reinhardt@amd.com assert_equal(x[1], *(it - 2)); 658319Ssteve.reinhardt@amd.com 668319Ssteve.reinhardt@amd.com checks.append(static_cast<std::size_t>(x.end() - x.begin()) == x.size()); 677047Snate@binkert.org checks.append((x.begin() + static_cast<std::ptrdiff_t>(x.size())) == x.end()); 688319Ssteve.reinhardt@amd.com checks.append(x.begin() < x.end()); 698319Ssteve.reinhardt@amd.com 707047Snate@binkert.org return checks; 
717047Snate@binkert.org} 728319Ssteve.reinhardt@amd.com 738319Ssteve.reinhardt@amd.comTEST_SUBMODULE(sequences_and_iterators, m) { 748319Ssteve.reinhardt@amd.com // test_sliceable 757047Snate@binkert.org class Sliceable{ 767047Snate@binkert.org public: 777047Snate@binkert.org Sliceable(int n): size(n) {} 781897Sstever@eecs.umich.edu int start,stop,step; 791897Sstever@eecs.umich.edu int size; 801897Sstever@eecs.umich.edu }; 811897Sstever@eecs.umich.edu py::class_<Sliceable>(m,"Sliceable") 828319Ssteve.reinhardt@amd.com .def(py::init<int>()) 838319Ssteve.reinhardt@amd.com .def("__getitem__",[](const Sliceable &s, py::slice slice) { 848319Ssteve.reinhardt@amd.com ssize_t start, stop, step, slicelength; 858319Ssteve.reinhardt@amd.com if (!slice.compute(s.size, &start, &stop, &step, &slicelength)) 868319Ssteve.reinhardt@amd.com throw py::error_already_set(); 878319Ssteve.reinhardt@amd.com int istart = static_cast<int>(start); 888319Ssteve.reinhardt@amd.com int istop = static_cast<int>(stop); 891897Sstever@eecs.umich.edu int istep = static_cast<int>(step); 908319Ssteve.reinhardt@amd.com return std::make_tuple(istart,istop,istep); 918811Sandreas.hansson@arm.com }) 928319Ssteve.reinhardt@amd.com ; 938319Ssteve.reinhardt@amd.com 941897Sstever@eecs.umich.edu // test_sequence 957047Snate@binkert.org class Sequence { 967047Snate@binkert.org public: 971897Sstever@eecs.umich.edu Sequence(size_t size) : m_size(size) { 981897Sstever@eecs.umich.edu print_created(this, "of size", m_size); 994961Ssaidi@eecs.umich.edu m_data = new float[size]; 1004961Ssaidi@eecs.umich.edu memset(m_data, 0, sizeof(float) * size); 1014961Ssaidi@eecs.umich.edu } 1024961Ssaidi@eecs.umich.edu Sequence(const std::vector<float> &value) : m_size(value.size()) { 1034961Ssaidi@eecs.umich.edu print_created(this, "of size", m_size, "from std::vector"); 1044961Ssaidi@eecs.umich.edu m_data = new float[m_size]; 1054961Ssaidi@eecs.umich.edu memcpy(m_data, &value[0], sizeof(float) * m_size); 
1064961Ssaidi@eecs.umich.edu } 1074961Ssaidi@eecs.umich.edu Sequence(const Sequence &s) : m_size(s.m_size) { 1084961Ssaidi@eecs.umich.edu print_copy_created(this); 1094961Ssaidi@eecs.umich.edu m_data = new float[m_size]; 1104961Ssaidi@eecs.umich.edu memcpy(m_data, s.m_data, sizeof(float)*m_size); 1114961Ssaidi@eecs.umich.edu } 1124961Ssaidi@eecs.umich.edu Sequence(Sequence &&s) : m_size(s.m_size), m_data(s.m_data) { 1131897Sstever@eecs.umich.edu print_move_created(this); 1148319Ssteve.reinhardt@amd.com s.m_size = 0; 1151897Sstever@eecs.umich.edu s.m_data = nullptr; 1168319Ssteve.reinhardt@amd.com } 1178319Ssteve.reinhardt@amd.com 1188816Sgblack@eecs.umich.edu ~Sequence() { print_destroyed(this); delete[] m_data; } 1198319Ssteve.reinhardt@amd.com 1208319Ssteve.reinhardt@amd.com Sequence &operator=(const Sequence &s) { 1218319Ssteve.reinhardt@amd.com if (&s != this) { 1228811Sandreas.hansson@arm.com delete[] m_data; 1234961Ssaidi@eecs.umich.edu m_size = s.m_size; 1248319Ssteve.reinhardt@amd.com m_data = new float[m_size]; 1258811Sandreas.hansson@arm.com memcpy(m_data, s.m_data, sizeof(float)*m_size); 1268814Sgblack@eecs.umich.edu } 1278319Ssteve.reinhardt@amd.com print_copy_assigned(this); 1288811Sandreas.hansson@arm.com return *this; 1298811Sandreas.hansson@arm.com } 1308811Sandreas.hansson@arm.com 1318811Sandreas.hansson@arm.com Sequence &operator=(Sequence &&s) { 1328811Sandreas.hansson@arm.com if (&s != this) { 1338811Sandreas.hansson@arm.com delete[] m_data; 1348811Sandreas.hansson@arm.com m_size = s.m_size; 1358811Sandreas.hansson@arm.com m_data = s.m_data; 1368811Sandreas.hansson@arm.com s.m_size = 0; 1371897Sstever@eecs.umich.edu s.m_data = nullptr; 1387047Snate@binkert.org } 1397047Snate@binkert.org print_move_assigned(this); 1407047Snate@binkert.org return *this; 1417047Snate@binkert.org } 1427047Snate@binkert.org 1437047Snate@binkert.org bool operator==(const Sequence &s) const { 1447047Snate@binkert.org if (m_size != s.size()) return false; 
1457047Snate@binkert.org for (size_t i = 0; i < m_size; ++i) 1467047Snate@binkert.org if (m_data[i] != s[i]) 1477047Snate@binkert.org return false; 1487047Snate@binkert.org return true; 1497047Snate@binkert.org } 1507047Snate@binkert.org bool operator!=(const Sequence &s) const { return !operator==(s); } 1517047Snate@binkert.org 1524961Ssaidi@eecs.umich.edu float operator[](size_t index) const { return m_data[index]; } 1534961Ssaidi@eecs.umich.edu float &operator[](size_t index) { return m_data[index]; } 1547047Snate@binkert.org 1557047Snate@binkert.org bool contains(float v) const { 1564961Ssaidi@eecs.umich.edu for (size_t i = 0; i < m_size; ++i) 1575247Sstever@gmail.com if (v == m_data[i]) 1585247Sstever@gmail.com return true; 1598319Ssteve.reinhardt@amd.com return false; 1608319Ssteve.reinhardt@amd.com } 1613725Sstever@eecs.umich.edu 1629843Ssteve.reinhardt@amd.com Sequence reversed() const { 1639843Ssteve.reinhardt@amd.com Sequence result(m_size); 1649843Ssteve.reinhardt@amd.com for (size_t i = 0; i < m_size; ++i) 1659843Ssteve.reinhardt@amd.com result[m_size - i - 1] = m_data[i]; 1669843Ssteve.reinhardt@amd.com return result; 1679843Ssteve.reinhardt@amd.com } 1688120Sgblack@eecs.umich.edu 1697047Snate@binkert.org size_t size() const { return m_size; } 1707047Snate@binkert.org 1717047Snate@binkert.org const float *begin() const { return m_data; } 1727047Snate@binkert.org const float *end() const { return m_data+m_size; } 1737047Snate@binkert.org 174 private: 175 size_t m_size; 176 float *m_data; 177 }; 178 py::class_<Sequence>(m, "Sequence") 179 .def(py::init<size_t>()) 180 .def(py::init<const std::vector<float>&>()) 181 /// Bare bones interface 182 .def("__getitem__", [](const Sequence &s, size_t i) { 183 if (i >= s.size()) throw py::index_error(); 184 return s[i]; 185 }) 186 .def("__setitem__", [](Sequence &s, size_t i, float v) { 187 if (i >= s.size()) throw py::index_error(); 188 s[i] = v; 189 }) 190 .def("__len__", &Sequence::size) 191 /// Optional 
sequence protocol operations 192 .def("__iter__", [](const Sequence &s) { return py::make_iterator(s.begin(), s.end()); }, 193 py::keep_alive<0, 1>() /* Essential: keep object alive while iterator exists */) 194 .def("__contains__", [](const Sequence &s, float v) { return s.contains(v); }) 195 .def("__reversed__", [](const Sequence &s) -> Sequence { return s.reversed(); }) 196 /// Slicing protocol (optional) 197 .def("__getitem__", [](const Sequence &s, py::slice slice) -> Sequence* { 198 size_t start, stop, step, slicelength; 199 if (!slice.compute(s.size(), &start, &stop, &step, &slicelength)) 200 throw py::error_already_set(); 201 Sequence *seq = new Sequence(slicelength); 202 for (size_t i = 0; i < slicelength; ++i) { 203 (*seq)[i] = s[start]; start += step; 204 } 205 return seq; 206 }) 207 .def("__setitem__", [](Sequence &s, py::slice slice, const Sequence &value) { 208 size_t start, stop, step, slicelength; 209 if (!slice.compute(s.size(), &start, &stop, &step, &slicelength)) 210 throw py::error_already_set(); 211 if (slicelength != value.size()) 212 throw std::runtime_error("Left and right hand size of slice assignment have different sizes!"); 213 for (size_t i = 0; i < slicelength; ++i) { 214 s[start] = value[i]; start += step; 215 } 216 }) 217 /// Comparisons 218 .def(py::self == py::self) 219 .def(py::self != py::self) 220 // Could also define py::self + py::self for concatenation, etc. 221 ; 222 223 // test_map_iterator 224 // Interface of a map-like object that isn't (directly) an unordered_map, but provides some basic 225 // map-like functionality. 
226 class StringMap { 227 public: 228 StringMap() = default; 229 StringMap(std::unordered_map<std::string, std::string> init) 230 : map(std::move(init)) {} 231 232 void set(std::string key, std::string val) { map[key] = val; } 233 std::string get(std::string key) const { return map.at(key); } 234 size_t size() const { return map.size(); } 235 private: 236 std::unordered_map<std::string, std::string> map; 237 public: 238 decltype(map.cbegin()) begin() const { return map.cbegin(); } 239 decltype(map.cend()) end() const { return map.cend(); } 240 }; 241 py::class_<StringMap>(m, "StringMap") 242 .def(py::init<>()) 243 .def(py::init<std::unordered_map<std::string, std::string>>()) 244 .def("__getitem__", [](const StringMap &map, std::string key) { 245 try { return map.get(key); } 246 catch (const std::out_of_range&) { 247 throw py::key_error("key '" + key + "' does not exist"); 248 } 249 }) 250 .def("__setitem__", &StringMap::set) 251 .def("__len__", &StringMap::size) 252 .def("__iter__", [](const StringMap &map) { return py::make_key_iterator(map.begin(), map.end()); }, 253 py::keep_alive<0, 1>()) 254 .def("items", [](const StringMap &map) { return py::make_iterator(map.begin(), map.end()); }, 255 py::keep_alive<0, 1>()) 256 ; 257 258 // test_generalized_iterators 259 class IntPairs { 260 public: 261 IntPairs(std::vector<std::pair<int, int>> data) : data_(std::move(data)) {} 262 const std::pair<int, int>* begin() const { return data_.data(); } 263 private: 264 std::vector<std::pair<int, int>> data_; 265 }; 266 py::class_<IntPairs>(m, "IntPairs") 267 .def(py::init<std::vector<std::pair<int, int>>>()) 268 .def("nonzero", [](const IntPairs& s) { 269 return py::make_iterator(NonZeroIterator<std::pair<int, int>>(s.begin()), NonZeroSentinel()); 270 }, py::keep_alive<0, 1>()) 271 .def("nonzero_keys", [](const IntPairs& s) { 272 return py::make_key_iterator(NonZeroIterator<std::pair<int, int>>(s.begin()), NonZeroSentinel()); 273 }, py::keep_alive<0, 1>()) 274 ; 275 276 277#if 
0 278 // Obsolete: special data structure for exposing custom iterator types to python 279 // kept here for illustrative purposes because there might be some use cases which 280 // are not covered by the much simpler py::make_iterator 281 282 struct PySequenceIterator { 283 PySequenceIterator(const Sequence &seq, py::object ref) : seq(seq), ref(ref) { } 284 285 float next() { 286 if (index == seq.size()) 287 throw py::stop_iteration(); 288 return seq[index++]; 289 } 290 291 const Sequence &seq; 292 py::object ref; // keep a reference 293 size_t index = 0; 294 }; 295 296 py::class_<PySequenceIterator>(seq, "Iterator") 297 .def("__iter__", [](PySequenceIterator &it) -> PySequenceIterator& { return it; }) 298 .def("__next__", &PySequenceIterator::next); 299 300 On the actual Sequence object, the iterator would be constructed as follows: 301 .def("__iter__", [](py::object s) { return PySequenceIterator(s.cast<const Sequence &>(), s); }) 302#endif 303 304 // test_python_iterator_in_cpp 305 m.def("object_to_list", [](py::object o) { 306 auto l = py::list(); 307 for (auto item : o) { 308 l.append(item); 309 } 310 return l; 311 }); 312 313 m.def("iterator_to_list", [](py::iterator it) { 314 auto l = py::list(); 315 while (it != py::iterator::sentinel()) { 316 l.append(*it); 317 ++it; 318 } 319 return l; 320 }); 321 322 // Make sure that py::iterator works with std algorithms 323 m.def("count_none", [](py::object o) { 324 return std::count_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); }); 325 }); 326 327 m.def("find_none", [](py::object o) { 328 auto it = std::find_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); }); 329 return it->is_none(); 330 }); 331 332 m.def("count_nonzeros", [](py::dict d) { 333 return std::count_if(d.begin(), d.end(), [](std::pair<py::handle, py::handle> p) { 334 return p.second.cast<int>() != 0; 335 }); 336 }); 337 338 m.def("tuple_iterator", &test_random_access_iterator<py::tuple>); 339 m.def("list_iterator", 
&test_random_access_iterator<py::list>); 340 m.def("sequence_iterator", &test_random_access_iterator<py::sequence>); 341 342 // test_iterator_passthrough 343 // #181: iterator passthrough did not compile 344 m.def("iterator_passthrough", [](py::iterator s) -> py::iterator { 345 return py::make_iterator(std::begin(s), std::end(s)); 346 }); 347 348 // test_iterator_rvp 349 // #388: Can't make iterators via make_iterator() with different r/v policies 350 static std::vector<int> list = { 1, 2, 3 }; 351 m.def("make_iterator_1", []() { return py::make_iterator<py::return_value_policy::copy>(list); }); 352 m.def("make_iterator_2", []() { return py::make_iterator<py::return_value_policy::automatic>(list); }); 353} 354