From 59fc047e5db498307fb1ba2ca0b71620674c872c Mon Sep 17 00:00:00 2001 From: Matti Picus Date: Sat, 1 May 2021 21:58:07 +0300 Subject: PyUnicode_Contains accepts bytes on python2 (comment to issue 3400) --- pypy/module/cpyext/test/test_unicodeobject.py | 48 ++++++++++++++++++++++----- pypy/module/cpyext/unicodeobject.py | 14 +++++--- 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/pypy/module/cpyext/test/test_unicodeobject.py b/pypy/module/cpyext/test/test_unicodeobject.py index 7791198c6e..1ec1f13eb7 100644 --- a/pypy/module/cpyext/test/test_unicodeobject.py +++ b/pypy/module/cpyext/test/test_unicodeobject.py @@ -182,6 +182,45 @@ class AppTestUnicodeObject(AppTestCpythonExtensionBase): tz1 = time.tzname[1] assert module.lower(tz1) == tz1.lower() + def test_contains(self): + import sys + module = self.import_extension('foo', [ + ("contains", "METH_VARARGS", + """ + PyObject *arg1 = PyTuple_GetItem(args, 0); + PyObject *arg2 = PyTuple_GetItem(args, 1); + int ret = PyUnicode_Contains(arg1, arg2); + if (ret < 0) { + return NULL; + } + return PyLong_FromLong(ret); + """)]) + s = u"abcabcabc" + assert module.contains(s, u"a") == 1 + assert module.contains(s, u"e") == 0 + try: + module.contains(s, 1) + except TypeError: + pass + else: + assert False + try: + module.contains(1, u"a") + except TypeError: + pass + else: + assert False + if sys.version_info < (3, 0): + assert module.contains(b'abcdef', b'e') == 1 + else: + try: + module.contains(b'abcdef', b'e') + except TypeError: + pass + else: + assert False + + class TestUnicode(BaseApiTest): def test_unicodeobject(self, space): @@ -695,15 +734,6 @@ class TestUnicode(BaseApiTest): assert PyUnicode_Find(space, w_str, space.wrap(u"c"), 0, 4, -1) == 2 assert PyUnicode_Find(space, w_str, space.wrap(u"z"), 0, 4, -1) == -1 - def test_contains(self, space): - w_str = space.wrap(u"abcabcd") - assert PyUnicode_Contains(space, w_str, space.wrap(u"a")) == 1 - assert PyUnicode_Contains(space, w_str, space.wrap(u"e")) == 0 - with raises_w(space, TypeError): - PyUnicode_Contains(space, w_str, space.wrap(1)) == -1 - with raises_w(space, TypeError) as e: - PyUnicode_Contains(space, space.wrap(1), space.wrap(u"a")) == -1 - def test_split(self, space): w_str = space.wrap(u"a\nb\nc\nd") assert "[u'a', u'b', u'c', u'd']" == space.unwrap(space.repr( diff --git a/pypy/module/cpyext/unicodeobject.py b/pypy/module/cpyext/unicodeobject.py index 8dfb2509df..33b96dcb4d 100644 --- a/pypy/module/cpyext/unicodeobject.py +++ b/pypy/module/cpyext/unicodeobject.py @@ -770,11 +770,17 @@ def PyUnicode_Contains(space, w_str, w_substr): element has to coerce to a one element Unicode string. -1 is returned if there was an error.""" if not space.isinstance_w(w_substr, space.w_unicode): - raise oefmt(space.w_TypeError, - "in requires string as left operand, not %T", - w_substr) + if space.isinstance_w(w_substr, space.w_bytes): + w_substr = space.call_method(w_substr, 'decode') + else: + raise oefmt(space.w_TypeError, + "in requires string as left operand, not %T", + w_substr) if not space.isinstance_w(w_str, space.w_unicode): - raise oefmt(space.w_TypeError, "must be str, not %T", w_str) + if space.isinstance_w(w_str, space.w_bytes): + w_str = space.call_method(w_str, 'decode') + else: + raise oefmt(space.w_TypeError, "must be str, not %T", w_str) return space.int_w(space.call_method(w_str, '__contains__', w_substr)) @cpython_api([PyObject, PyObject, Py_ssize_t], PyObject) -- cgit v1.2.3-65-gdbad