improved binary search animation version

This commit is contained in:
Bibin Muttappillil 2021-07-05 15:30:20 +02:00
parent 271e150077
commit a655236529
3 changed files with 105 additions and 85 deletions

View File

@ -1,92 +1,112 @@
from manim import * from manim import *
class Sequence(VMobject): class Main(Scene):
# evenly spaced characters def construct(self):
n = 13
k = 8
def __init__(self, seq): def check(x):
VMobject.__init__(self)
seq = [t.move_to(i * 0.5 * RIGHT) for i, t in enumerate(seq)]
self.add(VGroup(*seq).move_to(ORIGIN + 2*UP))
class Pointer(VMobject):
# pointer to sequence element
def __init__(self, pointee, name):
VMobject.__init__(self)
arrow = Arrow(ORIGIN, DOWN).next_to(pointee, UP)
name_tex = Tex(name).next_to(arrow, UP)
self.add(arrow, name_tex)
class ProgramCounter(VMobject):
# pointer to current code line
def __init__(self, line):
VMobject.__init__(self)
self.add(Arrow(ORIGIN, RIGHT).next_to(line, LEFT))
class Base(Scene):
def construct(self, n, k):
# initial situation
seq = [Tex("1")] + [Tex("?") for _ in range(n-2)] + [Tex("0")]
self.play(Create(Sequence(seq)))
# source code binary search
code = Code(file_name="binary-search.py", language="Python", insert_line_no=False, style='monokai')
code.move_to(DOWN)
self.play(Write(code))
line_lr, line_while, line_m, line_test, line_l, _, line_r = code[2]
def test(x):
return 1 if x <= k else 0 return 1 if x <= k else 0
def update_pc(pc, line): # initial situation
npc = ProgramCounter(line) seq = Matrix([[1] + ["?" for _ in range(n-1)] + [0]], h_buff=1.0).elements.move_to(0.5*UP)
self.play(ReplacementTransform(pc, npc)) self.play(Create(seq))
return npc
# source code binary search
# while(high - low > 1){ // line_while
# int mid = (low + high) / 2; // line_mid
# if(check(mid)) // line_check
# low = mid; // line_low
# else
# high = mid; // line_high
# }
code = Code(file_name="binary-search.cpp", language="c++", insert_line_no=False, style='monokai', tab_width=4, line_spacing=0.5).code
code.move_to(4*RIGHT + 2.5*DOWN)
while_line, mid_line, check_line, low_line, __, high_line, __ = code
# check code
check_code = Code(code="check(x)", language="c++", insert_line_no=False, style='monokai', tab_width=4, line_spacing=0.5).code
check_background, check_code = check_code.move_to(2.5*UP).add_background_rectangle(opacity=1.0)
self.play(Create(check_background), Write(check_code), Write(code))
def pointer(name):
arrow = Arrow(start=ORIGIN, end=UP)
label = Tex(name).next_to(arrow, DOWN)
return VGroup(arrow, label).scale(0.7)
# initialize borders # initialize borders
pc = ProgramCounter(line_lr) low = 0
self.play(Create(pc)) high = n
l, r = 0, n-1 low_pointer = pointer("low").next_to(seq[low], DOWN)
lp = Pointer(seq[l], "l") high_pointer = pointer("high").next_to(seq[high], DOWN)
rp = Pointer(seq[r], "r") self.play(Write(low_pointer), Create(high_pointer))
self.play(Write(lp), Create(rp))
while r - l > 1: mid_pointer = pointer("mid")
# calculate m
pc = update_pc(pc, line_m)
m = (l + r) // 2
mp = Pointer(seq[m], "m")
self.play(Create(mp))
# test def indicate_start(line):
pc = update_pc(pc, line_test) line.save_state()
old = seq.copy() self.play(line.animate.scale(1.3).set_color(YELLOW))
seq[m] = Tex(str(test(m)))
self.play(ReplacementTransform(Sequence(old), Sequence(seq)))
if test(m): def indicate_end(line):
# update left self.play(Restore(line))
pc = update_pc(pc, line_l)
l = m def high_low_test():
op = lp indicate_start(while_line)
lp = Pointer(seq[m], "l") brace = BraceText(VGroup(*seq[low:high]), "high - low = " + str(high - low), brace_direction=UP)
self.play(Uncreate(mp), ReplacementTransform(op, lp)) brace.label.scale(0.7)
self.play(FadeIn(brace))
self.wait(1)
self.play(FadeOut(brace))
indicate_end(while_line)
while high_low_test() or high - low > 1:
# calculate mid
indicate_start(mid_line)
mid = (low + high) // 2
mid_pointer.next_to(seq[mid], DOWN)
self.play(Write(mid_pointer))
indicate_end(mid_line)
# check
indicate_start(check_line)
seq[mid].add_background_rectangle(opacity=1.0)
path = Line(start=seq[mid].get_center(), end=check_background.get_center())
rpath = Line(start=path.get_end(), end=path.get_start())
laser = path.copy().set_length(0.3)
self.bring_to_back(laser)
self.play(MoveAlongPath(laser, path))
color = GREEN if (ans := check(mid)) else RED
self.play(Indicate(check_code, color=color))
self.bring_to_back(laser)
self.play(MoveAlongPath(laser, rpath))
self.remove(laser)
# TODO: mark tested
range_seq = seq[low+1:mid+1] if ans else seq[mid:high]
self.play(*[Transform(s, Tex(ans).move_to(s)) for s in range_seq])
indicate_end(check_line)
if check(mid):
# update low
indicate_start(low_line)
low = mid
low_pointer.generate_target().move_to(mid_pointer)
self.play(FadeOut(mid_pointer), MoveToTarget(low_pointer))
indicate_end(low_line)
else: else:
# update right # update high
pc = update_pc(pc, line_r) indicate_start(high_line)
r = m high = mid
op = rp high_pointer.generate_target().move_to(mid_pointer)
rp = Pointer(seq[m], "r") self.play(FadeOut(mid_pointer), MoveToTarget(high_pointer))
self.play(Uncreate(mp), ReplacementTransform(op, rp)) indicate_end(high_line)
self.wait(2) self.wait(2)
class Main(Base):
def construct(self):
Base.construct(self, 20, 6)

View File

@ -0,0 +1,7 @@
while(high - low > 1){
int mid = (low + high) / 2;
if(check(mid))
low = mid;
else
high = mid;
}

View File

@ -1,7 +0,0 @@
l, r = 0, n-1
while r - l > 1:
m = (l + r) / 2
if test(m):
l = m
else:
r = m