test_sequences_and_iterators.cpp revision 12037:d28054ac6ec9
1/* 2 tests/test_sequences_and_iterators.cpp -- supporting Pythons' sequence protocol, iterators, 3 etc. 4 5 Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch> 6 7 All rights reserved. Use of this source code is governed by a 8 BSD-style license that can be found in the LICENSE file. 9*/ 10 11#include "pybind11_tests.h" 12#include "constructor_stats.h" 13#include <pybind11/operators.h> 14#include <pybind11/stl.h> 15 16class Sequence { 17public: 18 Sequence(size_t size) : m_size(size) { 19 print_created(this, "of size", m_size); 20 m_data = new float[size]; 21 memset(m_data, 0, sizeof(float) * size); 22 } 23 24 Sequence(const std::vector<float> &value) : m_size(value.size()) { 25 print_created(this, "of size", m_size, "from std::vector"); 26 m_data = new float[m_size]; 27 memcpy(m_data, &value[0], sizeof(float) * m_size); 28 } 29 30 Sequence(const Sequence &s) : m_size(s.m_size) { 31 print_copy_created(this); 32 m_data = new float[m_size]; 33 memcpy(m_data, s.m_data, sizeof(float)*m_size); 34 } 35 36 Sequence(Sequence &&s) : m_size(s.m_size), m_data(s.m_data) { 37 print_move_created(this); 38 s.m_size = 0; 39 s.m_data = nullptr; 40 } 41 42 ~Sequence() { 43 print_destroyed(this); 44 delete[] m_data; 45 } 46 47 Sequence &operator=(const Sequence &s) { 48 if (&s != this) { 49 delete[] m_data; 50 m_size = s.m_size; 51 m_data = new float[m_size]; 52 memcpy(m_data, s.m_data, sizeof(float)*m_size); 53 } 54 55 print_copy_assigned(this); 56 57 return *this; 58 } 59 60 Sequence &operator=(Sequence &&s) { 61 if (&s != this) { 62 delete[] m_data; 63 m_size = s.m_size; 64 m_data = s.m_data; 65 s.m_size = 0; 66 s.m_data = nullptr; 67 } 68 69 print_move_assigned(this); 70 71 return *this; 72 } 73 74 bool operator==(const Sequence &s) const { 75 if (m_size != s.size()) 76 return false; 77 for (size_t i=0; i<m_size; ++i) 78 if (m_data[i] != s[i]) 79 return false; 80 return true; 81 } 82 83 bool operator!=(const Sequence &s) const { 84 return !operator==(s); 85 } 86 87 float operator[](size_t index) const { 88 return m_data[index]; 89 } 90 91 float &operator[](size_t index) { 92 return m_data[index]; 93 } 94 95 bool contains(float v) const { 96 for (size_t i=0; i<m_size; ++i) 97 if (v == m_data[i]) 98 return true; 99 return false; 100 } 101 102 Sequence reversed() const { 103 Sequence result(m_size); 104 for (size_t i=0; i<m_size; ++i) 105 result[m_size-i-1] = m_data[i]; 106 return result; 107 } 108 109 size_t size() const { return m_size; } 110 111 const float *begin() const { return m_data; } 112 const float *end() const { return m_data+m_size; } 113 114private: 115 size_t m_size; 116 float *m_data; 117}; 118 119class IntPairs { 120public: 121 IntPairs(std::vector<std::pair<int, int>> data) : data_(std::move(data)) {} 122 const std::pair<int, int>* begin() const { return data_.data(); } 123 124private: 125 std::vector<std::pair<int, int>> data_; 126}; 127 128// Interface of a map-like object that isn't (directly) an unordered_map, but provides some basic 129// map-like functionality. 130class StringMap { 131public: 132 StringMap() = default; 133 StringMap(std::unordered_map<std::string, std::string> init) 134 : map(std::move(init)) {} 135 136 void set(std::string key, std::string val) { 137 map[key] = val; 138 } 139 140 std::string get(std::string key) const { 141 return map.at(key); 142 } 143 144 size_t size() const { 145 return map.size(); 146 } 147 148private: 149 std::unordered_map<std::string, std::string> map; 150 151public: 152 decltype(map.cbegin()) begin() const { return map.cbegin(); } 153 decltype(map.cend()) end() const { return map.cend(); } 154}; 155 156template<typename T> 157class NonZeroIterator { 158 const T* ptr_; 159public: 160 NonZeroIterator(const T* ptr) : ptr_(ptr) {} 161 const T& operator*() const { return *ptr_; } 162 NonZeroIterator& operator++() { ++ptr_; return *this; } 163}; 164 165class NonZeroSentinel {}; 166 167template<typename A, typename B> 168bool operator==(const NonZeroIterator<std::pair<A, B>>& it, const NonZeroSentinel&) { 169 return !(*it).first || !(*it).second; 170} 171 172template <typename PythonType> 173py::list test_random_access_iterator(PythonType x) { 174 if (x.size() < 5) 175 throw py::value_error("Please provide at least 5 elements for testing."); 176 177 auto checks = py::list(); 178 auto assert_equal = [&checks](py::handle a, py::handle b) { 179 auto result = PyObject_RichCompareBool(a.ptr(), b.ptr(), Py_EQ); 180 if (result == -1) { throw py::error_already_set(); } 181 checks.append(result != 0); 182 }; 183 184 auto it = x.begin(); 185 assert_equal(x[0], *it); 186 assert_equal(x[0], it[0]); 187 assert_equal(x[1], it[1]); 188 189 assert_equal(x[1], *(++it)); 190 assert_equal(x[1], *(it++)); 191 assert_equal(x[2], *it); 192 assert_equal(x[3], *(it += 1)); 193 assert_equal(x[2], *(--it)); 194 assert_equal(x[2], *(it--)); 195 assert_equal(x[1], *it); 196 assert_equal(x[0], *(it -= 1)); 197 198 assert_equal(it->attr("real"), x[0].attr("real")); 199 assert_equal((it + 1)->attr("real"), x[1].attr("real")); 200 201 assert_equal(x[1], *(it + 1)); 202 assert_equal(x[1], *(1 + it)); 203 it += 3; 204 assert_equal(x[1], *(it - 2)); 205 206 checks.append(static_cast<std::size_t>(x.end() - x.begin()) == x.size()); 207 checks.append((x.begin() + static_cast<std::ptrdiff_t>(x.size())) == x.end()); 208 checks.append(x.begin() < x.end()); 209 210 return checks; 211} 212 213test_initializer sequences_and_iterators([](py::module &pm) { 214 auto m = pm.def_submodule("sequences_and_iterators"); 215 216 py::class_<Sequence> seq(m, "Sequence"); 217 218 seq.def(py::init<size_t>()) 219 .def(py::init<const std::vector<float>&>()) 220 /// Bare bones interface 221 .def("__getitem__", [](const Sequence &s, size_t i) { 222 if (i >= s.size()) 223 throw py::index_error(); 224 return s[i]; 225 }) 226 .def("__setitem__", [](Sequence &s, size_t i, float v) { 227 if (i >= s.size()) 228 throw py::index_error(); 229 s[i] = v; 230 }) 231 .def("__len__", &Sequence::size) 232 /// Optional sequence protocol operations 233 .def("__iter__", [](const Sequence &s) { return py::make_iterator(s.begin(), s.end()); }, 234 py::keep_alive<0, 1>() /* Essential: keep object alive while iterator exists */) 235 .def("__contains__", [](const Sequence &s, float v) { return s.contains(v); }) 236 .def("__reversed__", [](const Sequence &s) -> Sequence { return s.reversed(); }) 237 /// Slicing protocol (optional) 238 .def("__getitem__", [](const Sequence &s, py::slice slice) -> Sequence* { 239 size_t start, stop, step, slicelength; 240 if (!slice.compute(s.size(), &start, &stop, &step, &slicelength)) 241 throw py::error_already_set(); 242 Sequence *seq = new Sequence(slicelength); 243 for (size_t i=0; i<slicelength; ++i) { 244 (*seq)[i] = s[start]; start += step; 245 } 246 return seq; 247 }) 248 .def("__setitem__", [](Sequence &s, py::slice slice, const Sequence &value) { 249 size_t start, stop, step, slicelength; 250 if (!slice.compute(s.size(), &start, &stop, &step, &slicelength)) 251 throw py::error_already_set(); 252 if (slicelength != value.size()) 253 throw std::runtime_error("Left and right hand size of slice assignment have different sizes!"); 254 for (size_t i=0; i<slicelength; ++i) { 255 s[start] = value[i]; start += step; 256 } 257 }) 258 /// Comparisons 259 .def(py::self == py::self) 260 .def(py::self != py::self); 261 // Could also define py::self + py::self for concatenation, etc. 262 263 py::class_<StringMap> map(m, "StringMap"); 264 265 map .def(py::init<>()) 266 .def(py::init<std::unordered_map<std::string, std::string>>()) 267 .def("__getitem__", [](const StringMap &map, std::string key) { 268 try { return map.get(key); } 269 catch (const std::out_of_range&) { 270 throw py::key_error("key '" + key + "' does not exist"); 271 } 272 }) 273 .def("__setitem__", &StringMap::set) 274 .def("__len__", &StringMap::size) 275 .def("__iter__", [](const StringMap &map) { return py::make_key_iterator(map.begin(), map.end()); }, 276 py::keep_alive<0, 1>()) 277 .def("items", [](const StringMap &map) { return py::make_iterator(map.begin(), map.end()); }, 278 py::keep_alive<0, 1>()) 279 ; 280 281 py::class_<IntPairs>(m, "IntPairs") 282 .def(py::init<std::vector<std::pair<int, int>>>()) 283 .def("nonzero", [](const IntPairs& s) { 284 return py::make_iterator(NonZeroIterator<std::pair<int, int>>(s.begin()), NonZeroSentinel()); 285 }, py::keep_alive<0, 1>()) 286 .def("nonzero_keys", [](const IntPairs& s) { 287 return py::make_key_iterator(NonZeroIterator<std::pair<int, int>>(s.begin()), NonZeroSentinel()); 288 }, py::keep_alive<0, 1>()); 289 290 291#if 0 292 // Obsolete: special data structure for exposing custom iterator types to python 293 // kept here for illustrative purposes because there might be some use cases which 294 // are not covered by the much simpler py::make_iterator 295 296 struct PySequenceIterator { 297 PySequenceIterator(const Sequence &seq, py::object ref) : seq(seq), ref(ref) { } 298 299 float next() { 300 if (index == seq.size()) 301 throw py::stop_iteration(); 302 return seq[index++]; 303 } 304 305 const Sequence &seq; 306 py::object ref; // keep a reference 307 size_t index = 0; 308 }; 309 310 py::class_<PySequenceIterator>(seq, "Iterator") 311 .def("__iter__", [](PySequenceIterator &it) -> PySequenceIterator& { return it; }) 312 .def("__next__", &PySequenceIterator::next); 313 314 On the actual Sequence object, the iterator would be constructed as follows: 315 .def("__iter__", [](py::object s) { return PySequenceIterator(s.cast<const Sequence &>(), s); }) 316#endif 317 318 m.def("object_to_list", [](py::object o) { 319 auto l = py::list(); 320 for (auto item : o) { 321 l.append(item); 322 } 323 return l; 324 }); 325 326 m.def("iterator_to_list", [](py::iterator it) { 327 auto l = py::list(); 328 while (it != py::iterator::sentinel()) { 329 l.append(*it); 330 ++it; 331 } 332 return l; 333 }); 334 335 // Make sure that py::iterator works with std algorithms 336 m.def("count_none", [](py::object o) { 337 return std::count_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); }); 338 }); 339 340 m.def("find_none", [](py::object o) { 341 auto it = std::find_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); }); 342 return it->is_none(); 343 }); 344 345 m.def("count_nonzeros", [](py::dict d) { 346 return std::count_if(d.begin(), d.end(), [](std::pair<py::handle, py::handle> p) { 347 return p.second.cast<int>() != 0; 348 }); 349 }); 350 351 m.def("tuple_iterator", [](py::tuple x) { return test_random_access_iterator(x); }); 352 m.def("list_iterator", [](py::list x) { return test_random_access_iterator(x); }); 353 m.def("sequence_iterator", [](py::sequence x) { return test_random_access_iterator(x); }); 354}); 355