improved binary search animation version
This commit is contained in:
parent
271e150077
commit
a655236529
|
@ -1,92 +1,112 @@
|
||||||
from manim import *
|
from manim import *
|
||||||
|
|
||||||
|
|
||||||
class Sequence(VMobject):
|
class Main(Scene):
|
||||||
# evenly spaced characters
|
def construct(self):
|
||||||
|
n = 13
|
||||||
|
k = 8
|
||||||
|
|
||||||
def __init__(self, seq):
|
def check(x):
|
||||||
VMobject.__init__(self)
|
|
||||||
seq = [t.move_to(i * 0.5 * RIGHT) for i, t in enumerate(seq)]
|
|
||||||
self.add(VGroup(*seq).move_to(ORIGIN + 2*UP))
|
|
||||||
|
|
||||||
|
|
||||||
class Pointer(VMobject):
|
|
||||||
# pointer to sequence element
|
|
||||||
|
|
||||||
def __init__(self, pointee, name):
|
|
||||||
VMobject.__init__(self)
|
|
||||||
arrow = Arrow(ORIGIN, DOWN).next_to(pointee, UP)
|
|
||||||
name_tex = Tex(name).next_to(arrow, UP)
|
|
||||||
self.add(arrow, name_tex)
|
|
||||||
|
|
||||||
|
|
||||||
class ProgramCounter(VMobject):
|
|
||||||
# pointer to current code line
|
|
||||||
|
|
||||||
def __init__(self, line):
|
|
||||||
VMobject.__init__(self)
|
|
||||||
self.add(Arrow(ORIGIN, RIGHT).next_to(line, LEFT))
|
|
||||||
|
|
||||||
|
|
||||||
class Base(Scene):
|
|
||||||
def construct(self, n, k):
|
|
||||||
|
|
||||||
# initial situation
|
|
||||||
seq = [Tex("1")] + [Tex("?") for _ in range(n-2)] + [Tex("0")]
|
|
||||||
self.play(Create(Sequence(seq)))
|
|
||||||
|
|
||||||
# source code binary search
|
|
||||||
code = Code(file_name="binary-search.py", language="Python", insert_line_no=False, style='monokai')
|
|
||||||
code.move_to(DOWN)
|
|
||||||
self.play(Write(code))
|
|
||||||
line_lr, line_while, line_m, line_test, line_l, _, line_r = code[2]
|
|
||||||
|
|
||||||
def test(x):
|
|
||||||
return 1 if x <= k else 0
|
return 1 if x <= k else 0
|
||||||
|
|
||||||
def update_pc(pc, line):
|
# initial situation
|
||||||
npc = ProgramCounter(line)
|
seq = Matrix([[1] + ["?" for _ in range(n-1)] + [0]], h_buff=1.0).elements.move_to(0.5*UP)
|
||||||
self.play(ReplacementTransform(pc, npc))
|
self.play(Create(seq))
|
||||||
return npc
|
|
||||||
|
# source code binary search
|
||||||
|
|
||||||
|
# while(high - low > 1){ // line_while
|
||||||
|
# int mid = (low + high) / 2; // line_mid
|
||||||
|
# if(check(mid)) // line_check
|
||||||
|
# low = mid; // line_low
|
||||||
|
# else
|
||||||
|
# high = mid; // line_high
|
||||||
|
# }
|
||||||
|
code = Code(file_name="binary-search.cpp", language="c++", insert_line_no=False, style='monokai', tab_width=4, line_spacing=0.5).code
|
||||||
|
code.move_to(4*RIGHT + 2.5*DOWN)
|
||||||
|
while_line, mid_line, check_line, low_line, __, high_line, __ = code
|
||||||
|
|
||||||
|
# check code
|
||||||
|
check_code = Code(code="check(x)", language="c++", insert_line_no=False, style='monokai', tab_width=4, line_spacing=0.5).code
|
||||||
|
check_background, check_code = check_code.move_to(2.5*UP).add_background_rectangle(opacity=1.0)
|
||||||
|
self.play(Create(check_background), Write(check_code), Write(code))
|
||||||
|
|
||||||
|
def pointer(name):
|
||||||
|
arrow = Arrow(start=ORIGIN, end=UP)
|
||||||
|
label = Tex(name).next_to(arrow, DOWN)
|
||||||
|
return VGroup(arrow, label).scale(0.7)
|
||||||
|
|
||||||
# initialize borders
|
# initialize borders
|
||||||
pc = ProgramCounter(line_lr)
|
low = 0
|
||||||
self.play(Create(pc))
|
high = n
|
||||||
l, r = 0, n-1
|
low_pointer = pointer("low").next_to(seq[low], DOWN)
|
||||||
lp = Pointer(seq[l], "l")
|
high_pointer = pointer("high").next_to(seq[high], DOWN)
|
||||||
rp = Pointer(seq[r], "r")
|
self.play(Write(low_pointer), Create(high_pointer))
|
||||||
self.play(Write(lp), Create(rp))
|
|
||||||
|
|
||||||
while r - l > 1:
|
mid_pointer = pointer("mid")
|
||||||
# calculate m
|
|
||||||
pc = update_pc(pc, line_m)
|
|
||||||
m = (l + r) // 2
|
|
||||||
mp = Pointer(seq[m], "m")
|
|
||||||
self.play(Create(mp))
|
|
||||||
|
|
||||||
# test
|
def indicate_start(line):
|
||||||
pc = update_pc(pc, line_test)
|
line.save_state()
|
||||||
old = seq.copy()
|
self.play(line.animate.scale(1.3).set_color(YELLOW))
|
||||||
seq[m] = Tex(str(test(m)))
|
|
||||||
self.play(ReplacementTransform(Sequence(old), Sequence(seq)))
|
|
||||||
|
|
||||||
if test(m):
|
def indicate_end(line):
|
||||||
# update left
|
self.play(Restore(line))
|
||||||
pc = update_pc(pc, line_l)
|
|
||||||
l = m
|
def high_low_test():
|
||||||
op = lp
|
indicate_start(while_line)
|
||||||
lp = Pointer(seq[m], "l")
|
brace = BraceText(VGroup(*seq[low:high]), "high - low = " + str(high - low), brace_direction=UP)
|
||||||
self.play(Uncreate(mp), ReplacementTransform(op, lp))
|
brace.label.scale(0.7)
|
||||||
|
self.play(FadeIn(brace))
|
||||||
|
self.wait(1)
|
||||||
|
self.play(FadeOut(brace))
|
||||||
|
indicate_end(while_line)
|
||||||
|
|
||||||
|
while high_low_test() or high - low > 1:
|
||||||
|
|
||||||
|
# calculate mid
|
||||||
|
indicate_start(mid_line)
|
||||||
|
mid = (low + high) // 2
|
||||||
|
mid_pointer.next_to(seq[mid], DOWN)
|
||||||
|
self.play(Write(mid_pointer))
|
||||||
|
indicate_end(mid_line)
|
||||||
|
|
||||||
|
# check
|
||||||
|
indicate_start(check_line)
|
||||||
|
|
||||||
|
seq[mid].add_background_rectangle(opacity=1.0)
|
||||||
|
path = Line(start=seq[mid].get_center(), end=check_background.get_center())
|
||||||
|
rpath = Line(start=path.get_end(), end=path.get_start())
|
||||||
|
laser = path.copy().set_length(0.3)
|
||||||
|
|
||||||
|
self.bring_to_back(laser)
|
||||||
|
self.play(MoveAlongPath(laser, path))
|
||||||
|
|
||||||
|
color = GREEN if (ans := check(mid)) else RED
|
||||||
|
self.play(Indicate(check_code, color=color))
|
||||||
|
|
||||||
|
self.bring_to_back(laser)
|
||||||
|
self.play(MoveAlongPath(laser, rpath))
|
||||||
|
self.remove(laser)
|
||||||
|
|
||||||
|
# TODO: mark tested
|
||||||
|
range_seq = seq[low+1:mid+1] if ans else seq[mid:high]
|
||||||
|
self.play(*[Transform(s, Tex(ans).move_to(s)) for s in range_seq])
|
||||||
|
|
||||||
|
indicate_end(check_line)
|
||||||
|
|
||||||
|
if check(mid):
|
||||||
|
# update low
|
||||||
|
indicate_start(low_line)
|
||||||
|
low = mid
|
||||||
|
low_pointer.generate_target().move_to(mid_pointer)
|
||||||
|
self.play(FadeOut(mid_pointer), MoveToTarget(low_pointer))
|
||||||
|
indicate_end(low_line)
|
||||||
else:
|
else:
|
||||||
# update right
|
# update high
|
||||||
pc = update_pc(pc, line_r)
|
indicate_start(high_line)
|
||||||
r = m
|
high = mid
|
||||||
op = rp
|
high_pointer.generate_target().move_to(mid_pointer)
|
||||||
rp = Pointer(seq[m], "r")
|
self.play(FadeOut(mid_pointer), MoveToTarget(high_pointer))
|
||||||
self.play(Uncreate(mp), ReplacementTransform(op, rp))
|
indicate_end(high_line)
|
||||||
|
|
||||||
self.wait(2)
|
self.wait(2)
|
||||||
|
|
||||||
class Main(Base):
|
|
||||||
def construct(self):
|
|
||||||
Base.construct(self, 20, 6)
|
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
while(high - low > 1){
|
||||||
|
int mid = (low + high) / 2;
|
||||||
|
if(check(mid))
|
||||||
|
low = mid;
|
||||||
|
else
|
||||||
|
high = mid;
|
||||||
|
}
|
|
@ -1,7 +0,0 @@
|
||||||
l, r = 0, n-1
|
|
||||||
while r - l > 1:
|
|
||||||
m = (l + r) / 2
|
|
||||||
if test(m):
|
|
||||||
l = m
|
|
||||||
else:
|
|
||||||
r = m
|
|
Loading…
Reference in New Issue