improved binary search animation version
This commit is contained in:
parent
271e150077
commit
a655236529
|
@ -1,92 +1,112 @@
|
|||
from manim import *
|
||||
|
||||
|
||||
class Sequence(VMobject):
|
||||
# evenly spaced characters
|
||||
class Main(Scene):
|
||||
def construct(self):
|
||||
n = 13
|
||||
k = 8
|
||||
|
||||
def __init__(self, seq):
|
||||
VMobject.__init__(self)
|
||||
seq = [t.move_to(i * 0.5 * RIGHT) for i, t in enumerate(seq)]
|
||||
self.add(VGroup(*seq).move_to(ORIGIN + 2*UP))
|
||||
|
||||
|
||||
class Pointer(VMobject):
|
||||
# pointer to sequence element
|
||||
|
||||
def __init__(self, pointee, name):
|
||||
VMobject.__init__(self)
|
||||
arrow = Arrow(ORIGIN, DOWN).next_to(pointee, UP)
|
||||
name_tex = Tex(name).next_to(arrow, UP)
|
||||
self.add(arrow, name_tex)
|
||||
|
||||
|
||||
class ProgramCounter(VMobject):
|
||||
# pointer to current code line
|
||||
|
||||
def __init__(self, line):
|
||||
VMobject.__init__(self)
|
||||
self.add(Arrow(ORIGIN, RIGHT).next_to(line, LEFT))
|
||||
|
||||
|
||||
class Base(Scene):
|
||||
def construct(self, n, k):
|
||||
|
||||
# initial situation
|
||||
seq = [Tex("1")] + [Tex("?") for _ in range(n-2)] + [Tex("0")]
|
||||
self.play(Create(Sequence(seq)))
|
||||
|
||||
# source code binary search
|
||||
code = Code(file_name="binary-search.py", language="Python", insert_line_no=False, style='monokai')
|
||||
code.move_to(DOWN)
|
||||
self.play(Write(code))
|
||||
line_lr, line_while, line_m, line_test, line_l, _, line_r = code[2]
|
||||
|
||||
def test(x):
|
||||
def check(x):
|
||||
return 1 if x <= k else 0
|
||||
|
||||
def update_pc(pc, line):
|
||||
npc = ProgramCounter(line)
|
||||
self.play(ReplacementTransform(pc, npc))
|
||||
return npc
|
||||
# initial situation
|
||||
seq = Matrix([[1] + ["?" for _ in range(n-1)] + [0]], h_buff=1.0).elements.move_to(0.5*UP)
|
||||
self.play(Create(seq))
|
||||
|
||||
# source code binary search
|
||||
|
||||
# while(high - low > 1){ // line_while
|
||||
# int mid = (low + high) / 2; // line_mid
|
||||
# if(check(mid)) // line_check
|
||||
# low = mid; // line_low
|
||||
# else
|
||||
# high = mid; // line_high
|
||||
# }
|
||||
code = Code(file_name="binary-search.cpp", language="c++", insert_line_no=False, style='monokai', tab_width=4, line_spacing=0.5).code
|
||||
code.move_to(4*RIGHT + 2.5*DOWN)
|
||||
while_line, mid_line, check_line, low_line, __, high_line, __ = code
|
||||
|
||||
# check code
|
||||
check_code = Code(code="check(x)", language="c++", insert_line_no=False, style='monokai', tab_width=4, line_spacing=0.5).code
|
||||
check_background, check_code = check_code.move_to(2.5*UP).add_background_rectangle(opacity=1.0)
|
||||
self.play(Create(check_background), Write(check_code), Write(code))
|
||||
|
||||
def pointer(name):
|
||||
arrow = Arrow(start=ORIGIN, end=UP)
|
||||
label = Tex(name).next_to(arrow, DOWN)
|
||||
return VGroup(arrow, label).scale(0.7)
|
||||
|
||||
# initialize borders
|
||||
pc = ProgramCounter(line_lr)
|
||||
self.play(Create(pc))
|
||||
l, r = 0, n-1
|
||||
lp = Pointer(seq[l], "l")
|
||||
rp = Pointer(seq[r], "r")
|
||||
self.play(Write(lp), Create(rp))
|
||||
low = 0
|
||||
high = n
|
||||
low_pointer = pointer("low").next_to(seq[low], DOWN)
|
||||
high_pointer = pointer("high").next_to(seq[high], DOWN)
|
||||
self.play(Write(low_pointer), Create(high_pointer))
|
||||
|
||||
while r - l > 1:
|
||||
# calculate m
|
||||
pc = update_pc(pc, line_m)
|
||||
m = (l + r) // 2
|
||||
mp = Pointer(seq[m], "m")
|
||||
self.play(Create(mp))
|
||||
mid_pointer = pointer("mid")
|
||||
|
||||
# test
|
||||
pc = update_pc(pc, line_test)
|
||||
old = seq.copy()
|
||||
seq[m] = Tex(str(test(m)))
|
||||
self.play(ReplacementTransform(Sequence(old), Sequence(seq)))
|
||||
def indicate_start(line):
|
||||
line.save_state()
|
||||
self.play(line.animate.scale(1.3).set_color(YELLOW))
|
||||
|
||||
if test(m):
|
||||
# update left
|
||||
pc = update_pc(pc, line_l)
|
||||
l = m
|
||||
op = lp
|
||||
lp = Pointer(seq[m], "l")
|
||||
self.play(Uncreate(mp), ReplacementTransform(op, lp))
|
||||
def indicate_end(line):
|
||||
self.play(Restore(line))
|
||||
|
||||
def high_low_test():
|
||||
indicate_start(while_line)
|
||||
brace = BraceText(VGroup(*seq[low:high]), "high - low = " + str(high - low), brace_direction=UP)
|
||||
brace.label.scale(0.7)
|
||||
self.play(FadeIn(brace))
|
||||
self.wait(1)
|
||||
self.play(FadeOut(brace))
|
||||
indicate_end(while_line)
|
||||
|
||||
while high_low_test() or high - low > 1:
|
||||
|
||||
# calculate mid
|
||||
indicate_start(mid_line)
|
||||
mid = (low + high) // 2
|
||||
mid_pointer.next_to(seq[mid], DOWN)
|
||||
self.play(Write(mid_pointer))
|
||||
indicate_end(mid_line)
|
||||
|
||||
# check
|
||||
indicate_start(check_line)
|
||||
|
||||
seq[mid].add_background_rectangle(opacity=1.0)
|
||||
path = Line(start=seq[mid].get_center(), end=check_background.get_center())
|
||||
rpath = Line(start=path.get_end(), end=path.get_start())
|
||||
laser = path.copy().set_length(0.3)
|
||||
|
||||
self.bring_to_back(laser)
|
||||
self.play(MoveAlongPath(laser, path))
|
||||
|
||||
color = GREEN if (ans := check(mid)) else RED
|
||||
self.play(Indicate(check_code, color=color))
|
||||
|
||||
self.bring_to_back(laser)
|
||||
self.play(MoveAlongPath(laser, rpath))
|
||||
self.remove(laser)
|
||||
|
||||
# TODO: mark tested
|
||||
range_seq = seq[low+1:mid+1] if ans else seq[mid:high]
|
||||
self.play(*[Transform(s, Tex(ans).move_to(s)) for s in range_seq])
|
||||
|
||||
indicate_end(check_line)
|
||||
|
||||
if check(mid):
|
||||
# update low
|
||||
indicate_start(low_line)
|
||||
low = mid
|
||||
low_pointer.generate_target().move_to(mid_pointer)
|
||||
self.play(FadeOut(mid_pointer), MoveToTarget(low_pointer))
|
||||
indicate_end(low_line)
|
||||
else:
|
||||
# update right
|
||||
pc = update_pc(pc, line_r)
|
||||
r = m
|
||||
op = rp
|
||||
rp = Pointer(seq[m], "r")
|
||||
self.play(Uncreate(mp), ReplacementTransform(op, rp))
|
||||
# update high
|
||||
indicate_start(high_line)
|
||||
high = mid
|
||||
high_pointer.generate_target().move_to(mid_pointer)
|
||||
self.play(FadeOut(mid_pointer), MoveToTarget(high_pointer))
|
||||
indicate_end(high_line)
|
||||
|
||||
self.wait(2)
|
||||
|
||||
class Main(Base):
|
||||
def construct(self):
|
||||
Base.construct(self, 20, 6)
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
while(high - low > 1){
|
||||
int mid = (low + high) / 2;
|
||||
if(check(mid))
|
||||
low = mid;
|
||||
else
|
||||
high = mid;
|
||||
}
|
|
@ -1,7 +0,0 @@
|
|||
l, r = 0, n-1
|
||||
while r - l > 1:
|
||||
m = (l + r) / 2
|
||||
if test(m):
|
||||
l = m
|
||||
else:
|
||||
r = m
|
Loading…
Reference in New Issue