Skip to content
Snippets Groups Projects
pbar_bessy.py 7.83 KiB
Newer Older
Simone Vadilonga's avatar
Simone Vadilonga committed
import sys
Simone Vadilonga's avatar
Simone Vadilonga committed
from typing import Any, Optional, Callable
Simone Vadilonga's avatar
Simone Vadilonga committed
import abc
import numpy as np
Simone Vadilonga's avatar
Simone Vadilonga committed
from functools import partial
Simone Vadilonga's avatar
Simone Vadilonga committed
import threading
import time
from tqdm import tqdm
from tqdm.utils import _screen_shape_wrapper, _term_move_up, _unicode
import warnings

Simone Vadilonga's avatar
Simone Vadilonga committed
from bluesky.utils import _L2norm
Simone Vadilonga's avatar
Simone Vadilonga committed


class ProgressBarBase(abc.ABC):
    """Interface shared by progress-bar implementations.

    Implementations receive progress through :meth:`update`, remove their
    output in :meth:`clear`, and may be handed the interactive shell
    namespace through :meth:`update_shell_user_ns`. The defaults here do
    nothing beyond remembering that namespace.
    """

    def update(
            self,
            pos: Any,
            *,
            name: Optional[str] = None,
            current: Any = None,
            initial: Any = None,
            target: Any = None,
            unit: str = "units",
            precision: Any = None,
            fraction: Any = None,
            time_elapsed: Optional[float] = None,
            time_remaining: Optional[float] = None,
            motor_pos: Optional[str] = None,  # here we add the motor position
    ):
        """Record new progress for the bar identified by *pos*.

        Parameters
        ----------
        pos : Any
            Identifier of the bar being updated.
        name : str, optional
            Display name of the moving object.
        current, initial, target : Any, optional
            Current, start and goal positions.
        unit : str
            Unit label for the progress meter.
        precision : Any, optional
            Rounding precision for displayed values.
        fraction, time_elapsed, time_remaining : optional
            Extra progress information a status object may provide.
        motor_pos : str, optional
            Motor position to display alongside the bar.
        """

    def clear(self):
        """Remove the progress bar output. No-op by default."""

    def update_shell_user_ns(self, user_ns):
        """Remember the interactive namespace used to resolve device names.

        ``ProgressBarManager`` calls this on every bar its factory returns,
        so a default is provided to keep arbitrary subclasses compatible.
        """
        self.user_ns = user_ns


class TerminalProgressBar(ProgressBarBase):
    """Draw one terminal progress bar per watchable status object.

    Each bar's postfix shows the readback and engineering units of the axis
    named in the status update. The axis expression is obtained by replacing
    ``_`` with ``.`` in the name and is evaluated in the shell namespace set
    through :meth:`update_shell_user_ns`; whenever that fails, ``"n.a."`` is
    shown instead.
    """

    def __init__(self, status_objs, delay_draw=0.2):
        """
        Represent status objects with progress bars.

        Parameters
        ----------
        status_objs : list
            Status objects. Only those that have a ``watch`` method and are
            not already done get a bar.
        delay_draw : float, optional
            To avoid flashing progress bars that will complete quickly after
            they are displayed, delay drawing until the progress bar has been
            around for a while. Default is 0.2 seconds.
        """
        self.meters = []
        self.status_objs = []
        # Determine terminal width, falling back to 79 columns whenever the
        # tqdm helper cannot report one.
        shape_fn = _screen_shape_wrapper()
        width = shape_fn(sys.stdout)[0] if shape_fn is not None else None
        self.ncols = width or 79
        self.fp = sys.stdout
        self.creation_time = time.time()
        self.delay_draw = delay_draw
        self.drawn = False
        self.done = False
        self.lock = threading.RLock()
        self.motors = False
        # Shell namespace used to resolve axis names; assigned later through
        # update_shell_user_ns(). Anything that is not a dict disables lookups.
        self.user_ns = None

        # If the ProgressBar is not finished before the delay_draw time but
        # never again updated after the delay_draw time, we need to draw it
        # once.
        if delay_draw:
            threading.Thread(target=self._ensure_draw, daemon=True).start()

        # Create a closure over self.update for each status object that
        # implements the 'watch' method.
        for st in status_objs:
            with self.lock:
                if hasattr(st, 'watch') and not st.done:
                    pos = len(self.meters)
                    self.meters.append('')
                    self.status_objs.append(st)
                    st.watch(partial(self.update, pos))

    def update_shell_user_ns(self, user_ns):
        """Set the namespace in which axis names are evaluated."""
        self.user_ns = user_ns

    def update(self, pos, *,
               name=None,
               current=None, initial=None, target=None,
               unit='units', precision=None,
               fraction=None,
               time_elapsed=None, time_remaining=None,
               motor_pos=None,
               ):
        """
        Re-render the bar at index *pos* and redraw all bars.

        Parameters
        ----------
        pos : int
            Index of the bar, bound per status object in ``__init__``.
        name : str, optional
            Axis name; used as the bar prefix and to look up readback/units.
        current, initial, target : optional
            Positions. A real progress bar is drawn only when all three are
            given; otherwise a "Complete"/"In progress" line is shown.
        unit, precision, time_elapsed : optional
            Formatting inputs for the meter (``precision`` defaults to 3).
        fraction, time_remaining : optional
            Accepted but currently unused (see TODOs below).
        motor_pos : optional
            Already-known readback to display instead of querying the shell
            namespace; accepted to match ``ProgressBarBase.update``.
        """
        if all(x is not None for x in (current, initial, target)):
            # Display a proper progress bar.
            total = round(_L2norm(target, initial), precision or 3)
            # make sure we ignore overshoot to prevent tqdm from exploding.
            n = np.clip(round(_L2norm(current, initial), precision or 3), 0, total)
            # Compute this only if the status object did not provide it.
            if time_elapsed is None:
                time_elapsed = time.time() - self.creation_time
            # TODO Account for 'fraction', which might in some special cases
            # differ from the naive computation above.
            # TODO Account for 'time_remaining' which might in some special
            # cases differ from the naive computation performed by
            # format_meter
            if motor_pos is None:
                motor_position = self._read_motor_position(name)
            else:
                motor_position = motor_pos
            egu = self._read_egu(name)
            meter = tqdm.format_meter(n=n, total=total, elapsed=time_elapsed,
                                      unit=unit,
                                      prefix=name,
                                      ncols=self.ncols,
                                      postfix=f'- {motor_position}{egu}',
                                      )
        else:
            # Simply display completeness.
            if name is None:
                name = ''
            if self.status_objs[pos].done:
                meter = name + ' [Complete.]'
            else:
                meter = name + ' [In progress. No progress bar available.]'
            meter += ' ' * (self.ncols - len(meter))
            meter = meter[:self.ncols]

        self.meters[pos] = meter
        self.draw()

    def _eval_in_user_ns(self, expression):
        """Evaluate *expression* in the shell namespace; raise if none is set."""
        if not isinstance(self.user_ns, dict):
            # eval() would otherwise reject a non-dict or, for None, silently
            # fall back to this module's globals.
            raise LookupError("no shell namespace to evaluate in")
        # NOTE(review): eval() executes arbitrary code. The expression is built
        # from the status object's name, presumably a trusted device name -
        # confirm names cannot come from external input.
        return eval(expression, self.user_ns)

    def _read_motor_position(self, name):
        """Return the readback of axis *name* rounded to 3 decimals, or "n.a."."""
        if name is None:
            return "n.a."
        axis = name.replace("_", ".")
        try:
            # First attempt: a coroutine read() driven by asyncio.run().
            return np.round(self._eval_in_user_ns(f"asyncio.run({axis}.read())"), 3)
        except Exception:
            pass
        try:
            # Fallback: a synchronous user_readback signal.
            return np.round(self._eval_in_user_ns(axis + ".user_readback.get()"), 3)
        except Exception:
            # Best effort only: never let a failed lookup break the bar.
            return "n.a."

    def _read_egu(self, name):
        """Return the engineering units of axis *name*, or 'n.a.'."""
        if name is None:
            return 'n.a.'
        try:
            return self._eval_in_user_ns(name.replace("_", ".") + ".egu")
        except Exception:
            return 'n.a.'

    def draw(self):
        """Write all meters to the terminal and move the cursor back up."""
        with self.lock:
            if (time.time() - self.creation_time) < self.delay_draw:
                return
            if self.done:
                return
            for meter in self.meters:
                tqdm.status_printer(self.fp)(meter)
                self.fp.write('\n')
            self.fp.write(_unicode(_term_move_up() * len(self.meters)))
            self.drawn = True

    def _ensure_draw(self):
        """Draw once after the delay if nothing has been drawn yet."""
        time.sleep(self.delay_draw)
        with self.lock:
            if (not self.done) and (not self.drawn):
                self.draw()

    def clear(self):
        """Mark the bars done and blank every line they occupied."""
        with self.lock:
            self.done = True
            if self.drawn:
                for meter in self.meters:
                    self.fp.write('\r')
                    self.fp.write(' ' * self.ncols)
                    self.fp.write('\r')
                    self.fp.write('\n')
                self.fp.write(_unicode(_term_move_up() * len(self.meters)))


class ProgressBar(TerminalProgressBar):
    """Backwards-compatible name for :class:`TerminalProgressBar`.

    Retained so existing code that imports ``ProgressBar`` keeps working;
    it adds no behavior of its own.
    """


def default_progress_bar(status_objs_or_none) -> ProgressBarBase:
    """Build the standard terminal progress bar for the given status objects.

    Drawing is delayed by 0.2 s to avoid flashing bars for moves that finish
    almost immediately.
    """
    bar = TerminalProgressBar(status_objs_or_none, delay_draw=0.2)
    return bar


class ProgressBarManager:
    """Create and tear down progress bars as status objects come and go."""

    pbar_factory: Callable[[Any], ProgressBarBase]
    pbar: Optional[ProgressBarBase]

    def __init__(self,
                 pbar_factory: Callable[[Any], ProgressBarBase] = default_progress_bar,
                 user_ns: Optional[dict] = None):
        """
        Manages creation and tearing down of progress bars.

        Parameters
        ----------
        pbar_factory : Callable[[Any], ProgressBar], optional
            A function that creates a progress bar given an optional list of status objects,
            by default default_progress_bar
        user_ns : dict, optional
            Shell namespace forwarded to every new progress bar so it can
            resolve device names. May also be assigned directly on the
            instance later.
        """

        self.pbar_factory = pbar_factory
        self.pbar = None
        self.user_ns = user_ns

    def __call__(self, status_objs_or_none):
        """
        Updates the manager with a new set of status, creates a new progress bar and
        cleans up the old one if needed.

        Parameters
        ----------
        status_objs_or_none : Set[Status], optional
            Optional list of status objects to be passed to the factory.
            ``None`` means the current progress bar should be cleared.
        """
        if status_objs_or_none is not None:
            # Start a new ProgressBar.
            if self.pbar is not None:
                warnings.warn("Previous ProgressBar never completed.")
                self.pbar.clear()
            self.pbar = self.pbar_factory(status_objs_or_none)
            # A factory may return any ProgressBarBase; only forward the
            # namespace to bars that provide the hook.
            forward_ns = getattr(self.pbar, "update_shell_user_ns", None)
            if forward_ns is not None:
                forward_ns(self.user_ns)
        else:
            # Clean up an old one.
            if self.pbar is None:
                warnings.warn("There is no Progress bar to clean up.")
            else:
                self.pbar.clear()
                self.pbar = None