aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Carl Friedrich Bolz-Tereick <cfbolz@gmx.de> 2021-02-26 11:58:51 +0100
committer: Carl Friedrich Bolz-Tereick <cfbolz@gmx.de> 2021-02-26 11:58:51 +0100
commit: 28fab560253e17f0e6f5b5f2c2f5d443cf50155f (patch)
tree: 5b2bec6be7c0f483a354b0c09150d167b444e345 /rpython
parent: add a random test for finding (diff)
download: pypy-28fab560253e17f0e6f5b5f2c2f5d443cf50155f.tar.gz
pypy-28fab560253e17f0e6f5b5f2c2f5d443cf50155f.tar.bz2
pypy-28fab560253e17f0e6f5b5f2c2f5d443cf50155f.zip
Follow what CPython is doing more systematically:
add similar special cases, stop using StringBuilder, and build a correctly sized llstr directly. Needs a refactoring.
Diffstat (limited to 'rpython')
-rw-r--r-- rpython/rlib/rstring.py 130
-rw-r--r-- rpython/rlib/test/test_rstring.py 16
2 files changed, 117 insertions, 29 deletions
diff --git a/rpython/rlib/rstring.py b/rpython/rlib/rstring.py
index ed7a61d734..c77a364069 100644
--- a/rpython/rlib/rstring.py
+++ b/rpython/rlib/rstring.py
@@ -251,34 +251,10 @@ def replace_count(input, sub, by, maxsplit=-1, isutf8=False):
builder.append_slice(input, upper, len(input))
replacements = upper + 1
- elif isinstance(sub, str) and len(sub) == 1:
- # a copy of the code that is specialized for single (ascii) characters
- sub = sub[0]
- cnt = count(input, sub, 0, len(input))
- if cnt == 0:
- return input, 0
- if maxsplit > 0 and cnt > maxsplit:
- cnt = maxsplit
- diff_len = len(by) - 1
- try:
- result_size = ovfcheck(diff_len * cnt)
- result_size = ovfcheck(result_size + len(input))
- except OverflowError:
- raise
- replacements = cnt
-
- builder = Builder(result_size)
- start = 0
- while maxsplit != 0:
- next = find(input, sub, start, len(input))
- if next < 0:
- break
- builder.append_slice(input, start, next)
- builder.append(by)
- start = next + 1
- maxsplit -= 1 # NB. if it's already < 0, it stays < 0
-
- builder.append_slice(input, start, len(input))
+ elif isinstance(input, str) and len(sub) == 1:
+ if len(by) == 1:
+ return replace_count_str_chr_chr(input, sub[0], by[0], maxsplit)
+ return replace_count_str_chr_str(input, sub[0], by, maxsplit)
else:
# First compute the exact result size
@@ -286,6 +262,8 @@ def replace_count(input, sub, by, maxsplit=-1, isutf8=False):
cnt = count(input, sub, 0, len(input))
if isinstance(input, str) and cnt == 0:
return input, 0
+ if isinstance(input, str):
+ return replace_count_str_str_str(input, sub, by, cnt, maxsplit)
else:
assert isutf8
from rpython.rlib import rutf8
@@ -330,6 +308,102 @@ def replace_count(input, sub, by, maxsplit=-1, isutf8=False):
return builder.build(), replacements
def replace_count_str_chr_chr(input, c1, c2, maxsplit):
    """Replace occurrences of the character c1 in input by the character c2.

    Both c1 and c2 are single characters, so the result has the same length
    as input: input is copied once into a newly allocated low-level string
    and the matching positions are then overwritten one by one.

    maxsplit limits the number of replacements.  A negative value means
    "no limit"; 0 means "replace nothing".

    Returns a tuple (result, number_of_replacements).  When nothing is
    replaced, input itself is returned together with 0.
    """
    if maxsplit == 0:
        # The first occurrence is replaced unconditionally below and the
        # limit is then decremented; without this guard a limit of 0 would
        # become -1, i.e. "unlimited".
        return input, 0
    from rpython.rtyper.annlowlevel import llstr, hlstr
    inputlen = len(input)
    first = find(input, c1, 0, inputlen)
    if first < 0:
        return input, 0

    s = llstr(input)
    newstr = s.malloc(inputlen)
    dst = newstr.chars
    # Start from a full copy of the input, then patch the matches in place.
    s.copy_contents(s, newstr, 0, 0, inputlen)
    dst[first] = c2
    replaced = 1
    start = first + 1
    maxsplit -= 1
    while maxsplit != 0:
        pos = find(input, c1, start, inputlen)
        if pos < 0:
            break
        dst[pos] = c2
        start = pos + 1
        maxsplit -= 1  # NB. if it's already < 0, it stays < 0
        replaced += 1

    return hlstr(newstr), replaced
+
def replace_count_str_chr_str(input, sub, by, maxsplit):
    """Replace occurrences of the single character sub in input by the string by.

    The number of replacements is determined first, so that the result can
    be allocated with its exact final size and filled by copying slices of
    input and copies of by into it.

    maxsplit limits the number of replacements.  A negative value means
    "no limit"; 0 means "replace nothing".

    Returns a tuple (result, number_of_replacements).  When nothing is
    replaced, input itself is returned together with 0.

    Raises OverflowError if the size of the result overflows.
    """
    if maxsplit == 0:
        # Otherwise the size below would be computed for every occurrence
        # while the copy loop performs no replacement at all, which breaks
        # the size assertion at the end.
        return input, 0
    from rpython.rtyper.annlowlevel import llstr, hlstr
    inputlen = len(input)
    cnt = count(input, sub, 0, inputlen)
    if cnt == 0:
        return input, 0
    if maxsplit > 0 and cnt > maxsplit:
        cnt = maxsplit
    bylen = len(by)
    # Each replacement removes one character and inserts bylen characters.
    diff_len = bylen - 1
    # NOTE(review): the catch-and-re-raise looks redundant but is kept from
    # the original; presumably RPython wants ovfcheck() inside a try block
    # that handles OverflowError -- confirm before simplifying.
    try:
        result_size = ovfcheck(diff_len * cnt)
        result_size = ovfcheck(result_size + inputlen)
    except OverflowError:
        raise

    s = llstr(input)
    by_ll = llstr(by)

    newstr = s.malloc(result_size)
    dst = 0      # next write position in newstr
    start = 0    # next read position in input
    while maxsplit != 0:
        pos = find(input, sub, start, inputlen)
        if pos < 0:
            break
        # Copy the unchanged text before the match, then the replacement.
        s.copy_contents(s, newstr, start, dst, pos - start)
        dst += pos - start
        s.copy_contents(by_ll, newstr, 0, dst, bylen)
        dst += bylen

        start = pos + 1
        maxsplit -= 1  # NB. if it's already < 0, it stays < 0

    # Copy whatever follows the last replaced occurrence.
    s.copy_contents(s, newstr, start, dst, inputlen - start)
    assert dst - start + inputlen == result_size
    return hlstr(newstr), cnt
+
def replace_count_str_str_str(input, sub, by, cnt, maxsplit):
    """Replace occurrences of the string sub in input by the string by.

    cnt is the number of occurrences of sub in input, as already counted by
    the caller.  It is capped by maxsplit and used to allocate the result
    with its exact final size, which is then filled by copying slices of
    input and copies of by into it.

    maxsplit limits the number of replacements.  A negative value means
    "no limit"; 0 means "replace nothing".

    Returns a tuple (result, number_of_replacements).

    Raises OverflowError if the size of the result overflows.
    """
    if maxsplit == 0:
        # Otherwise the size below would be computed for cnt replacements
        # while the copy loop performs none, which breaks the size
        # assertion at the end.
        return input, 0
    from rpython.rtyper.annlowlevel import llstr, hlstr
    if cnt > maxsplit and maxsplit > 0:
        cnt = maxsplit
    sublen = len(sub)
    bylen = len(by)
    inputlen = len(input)
    # Each replacement removes sublen characters and inserts bylen.
    diff_len = bylen - sublen
    # NOTE(review): the catch-and-re-raise looks redundant but is kept from
    # the original; presumably RPython wants ovfcheck() inside a try block
    # that handles OverflowError -- confirm before simplifying.
    try:
        result_size = ovfcheck(diff_len * cnt)
        result_size = ovfcheck(result_size + inputlen)
    except OverflowError:
        raise

    s = llstr(input)
    by_ll = llstr(by)
    newstr = s.malloc(result_size)
    dst = 0      # next write position in newstr
    start = 0    # next read position in input
    while maxsplit != 0:
        pos = find(input, sub, start, inputlen)
        if pos < 0:
            break
        # Copy the unchanged text before the match, then the replacement.
        s.copy_contents(s, newstr, start, dst, pos - start)
        dst += pos - start
        s.copy_contents(by_ll, newstr, 0, dst, bylen)
        dst += bylen
        start = pos + sublen
        maxsplit -= 1  # NB. if it's already < 0, it stays < 0
    # Copy whatever follows the last replaced occurrence.
    s.copy_contents(s, newstr, start, dst, inputlen - start)
    assert dst - start + inputlen == result_size
    return hlstr(newstr), cnt
+
+
def _normalize_start_end(length, start, end):
if start < 0:
start += length
diff --git a/rpython/rlib/test/test_rstring.py b/rpython/rlib/test/test_rstring.py
index 66c23dbd52..a2ce14e0a8 100644
--- a/rpython/rlib/test/test_rstring.py
+++ b/rpython/rlib/test/test_rstring.py
@@ -6,7 +6,7 @@ from rpython.rlib.rstring import find, rfind, count, _search, SEARCH_COUNT, SEAR
from rpython.rlib.buffer import StringBuffer
from rpython.rtyper.test.tool import BaseRtypingTest
-from hypothesis import given, strategies as st
+from hypothesis import given, strategies as st, assume
def test_split():
def check_split(value, sub, *args, **kwargs):
@@ -326,3 +326,17 @@ def test_hypothesis_search(u, prefix, suffix):
count = _search(s, u, 0, len(s), SEARCH_COUNT)
assert count == s.count(u)
assert 1 <= count
+
+
@given(st.text(), st.lists(st.text(), min_size=2), st.text(), st.integers(min_value=0, max_value=1000000))
def test_hypothesis_replace(needle, pieces, by, maxcount):
    # Property test: replace() must give the same result as the built-in
    # str.replace() for the same arguments.
    #
    # Named test_hypothesis_replace because test_hypothesis_search is already
    # the name of the test defined just above in this module; reusing it
    # would rebind the name, so that earlier test would no longer be
    # collected.
    needle = needle.encode("utf-8")
    pieces = [piece.encode("utf-8") for piece in pieces]
    by = by.encode("utf-8")
    # Joining at least two pieces with the needle makes sure the needle
    # occurs in the text at least once.
    text = needle.join(pieces)
    assume(len(text) > 0)

    res = replace(text, needle, by, maxcount)
    assert res == text.replace(needle, by, maxcount)