# Source code for astroquery.query

# Licensed under a 3-clause BSD style license - see LICENSE.rst
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)
import abc
import inspect
import pickle
import copy
import getpass
import hashlib
import keyring
import io
import os
import requests
import textwrap

from astropy.config import paths
from astroquery import log
import astropy.units as u
from astropy.utils.console import ProgressBarOrSpinner
import astropy.utils.data

from . import version
from .utils import system_tools

__all__ = ['BaseQuery', 'QueryWithLogin']


def to_cache(response, cache_file):
    """
    Pickle a copy of ``response`` into the file ``cache_file``.

    The response is deep-copied first, so the caller's object is left
    untouched. If the copy has a ``request`` attribute, every entry of its
    ``request.hooks`` mapping is removed before pickling (presumably so the
    hook callables -- BaseQuery registers a bound method as a response hook
    -- are not pickled; TODO confirm).

    Parameters
    ----------
    response : object
        The response to cache, normally a `requests.Response`.
    cache_file : str
        Destination path; it is opened in binary write mode and overwritten.
    """
    log.debug("Caching data to {0}".format(cache_file))
    snapshot = copy.deepcopy(response)

    if hasattr(snapshot, 'request'):
        hooks = snapshot.request.hooks
        # Materialize the key list first: we mutate the mapping while looping.
        for hook_name in list(hooks):
            hooks.pop(hook_name)

    with open(cache_file, "wb") as handle:
        pickle.dump(snapshot, handle)


def _replace_none_iterable(iterable):
    return tuple('' if i is None else i for i in iterable)


class AstroQuery:
    """
    Description of a single HTTP request that can be sent through a
    `requests.Session` and cached on disk.

    The request is identified by a SHA-224 digest of its method, URL and
    payload (see `hash`); that digest names the cache file (see
    `request_file`).
    """

    def __init__(self, method, url,
                 params=None, data=None, headers=None,
                 files=None, timeout=None, json=None):
        self.method = method
        self.url = url
        self.params = params
        self.data = data
        self.json = json
        self.headers = headers
        self.files = files
        self._hash = None  # memoized result of hash()
        self.timeout = timeout

    @property
    def timeout(self):
        """Request timeout as stored by the setter (seconds, or None)."""
        return self._timeout

    @timeout.setter
    def timeout(self, value):
        # Values with a ``to`` method (e.g. astropy Quantities) are converted
        # to a plain number of seconds; anything else is stored unchanged.
        if hasattr(value, 'to'):
            self._timeout = value.to(u.s).value
        else:
            self._timeout = value

    def request(self, session, cache_location=None, stream=False,
                auth=None, verify=True, allow_redirects=True,
                json=None):
        """
        Send this request through ``session`` and return the response.

        ``cache_location`` is accepted for interface compatibility and is not
        used. When ``json`` is None, the ``json`` payload given to the
        constructor is sent, so the request actually sent matches the one
        identified by `hash`.
        """
        if json is None:
            json = self.json
        return session.request(self.method, self.url, params=self.params,
                               data=self.data, headers=self.headers,
                               files=self.files, timeout=self.timeout,
                               stream=stream, auth=auth, verify=verify,
                               allow_redirects=allow_redirects,
                               json=json)

    @staticmethod
    def _read_and_rewind(stream):
        """Read the rest of file-like ``stream``, then restore its position."""
        try:
            start = stream.tell()
        except (AttributeError, OSError):
            # No position to restore (no ``tell``, or an unseekable stream).
            start = None
        content = stream.read()
        if start is not None:
            stream.seek(start)
        return content

    def hash(self):
        """
        Return a hex SHA-224 digest identifying this request.

        The key is built from the method, URL, and each of ``params``,
        ``data``, ``json``, ``headers`` and ``files``:

        * dicts contribute their items, sorted with ``None`` treated as
          ``''``; file-like values are replaced by their contents, and each
          such file is rewound afterwards so the real request still sends the
          whole file;
        * lists/tuples contribute one sorted tuple;
        * ``str``/``bytes`` contribute themselves; ``None`` contributes ``None``.

        The digest is computed once and memoized.

        Raises
        ------
        TypeError
            If a payload component has an unsupported type.
        """
        if self._hash is None:
            request_key = (self.method, self.url)
            for component in (self.params, self.data, self.json,
                              self.headers, self.files):
                if isinstance(component, dict):
                    items = sorted(component.items(),
                                   key=_replace_none_iterable)
                    request_key += tuple(
                        (key, self._read_and_rewind(value))
                        if hasattr(value, 'read') else (key, value)
                        for key, value in items)
                elif isinstance(component, (tuple, list)):
                    request_key += (tuple(sorted(
                        component, key=_replace_none_iterable)),)
                elif component is None:
                    request_key += (None,)
                elif isinstance(component, (str, bytes)):
                    request_key += (component,)
                else:
                    raise TypeError("{0} must be a dict, tuple, list, str, "
                                    "bytes, or None".format(component))
            self._hash = hashlib.sha224(pickle.dumps(request_key)).hexdigest()
        return self._hash

    def request_file(self, cache_location):
        """Return the cache file path for this request inside ``cache_location``."""
        fn = os.path.join(cache_location, self.hash() + ".pickle")
        return fn

    def from_cache(self, cache_location):
        """
        Load the cached response for this request from ``cache_location``.

        Returns the unpickled `requests.Response`, or None when the cache file
        is missing or unreadable, is truncated or not a valid pickle, or does
        not hold a `requests.Response`.
        """
        request_file = self.request_file(cache_location)
        try:
            # The cache files are written locally by this module; never point
            # ``cache_location`` at untrusted data, since pickle can execute code.
            with open(request_file, "rb") as f:
                response = pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError):
            # A missing file is a normal cache miss. A truncated or corrupt
            # pickle (e.g. from an interrupted write) is treated the same way,
            # so the caller re-fetches instead of crashing.
            response = None
        else:
            if not isinstance(response, requests.Response):
                response = None
        # NOTE: a requests.Response is truthy only when its status is ok, so
        # cached error responses are not logged here.
        if response:
            log.debug("Retrieving data from {0}".format(request_file))
        return response

    def remove_cache_file(self, cache_location):
        """
        Remove the cache file - may be needed if a query fails during parsing
        (successful request, but failed return)

        Raises
        ------
        OSError
            If the cache file does not exist.
        """
        request_file = self.request_file(cache_location)

        if os.path.exists(request_file):
            os.remove(request_file)
        else:
            raise OSError(f"Tried to remove cache file {request_file} but "
                          "it does not exist")


class LoginABCMeta(abc.ABCMeta):
    """
    The goal of this metaclass is to copy the docstring and signature from
    ._login methods, implemented in subclasses, to a .login method that is
    visible by the users.

    The generated ``login`` forwards to the first base class's ``login`` and
    returns whatever that returns (for ``QueryWithLogin`` this is the
    authentication status).

    It also inherits from the ABCMeta metaclass as _login is an abstract
    method.

    """

    def __new__(cls, name, bases, attrs):
        newcls = super().__new__(cls, name, bases, attrs)

        if '_login' in attrs and name not in ('BaseQuery', 'QueryWithLogin'):
            # Skip these two classes, BaseQuery and QueryWithLogin, so that
            # below bases[0] is expected to be QueryWithLogin (or a subclass
            # of it) and to provide ``login``.
            def login(*args, **kwargs):
                # Propagate the base result instead of silently returning None.
                return bases[0].login(*args, **kwargs)

            login.__doc__ = attrs['_login'].__doc__
            login.__signature__ = inspect.signature(attrs['_login'])
            setattr(newcls, login.__name__, login)

        return newcls


class BaseQuery(metaclass=LoginABCMeta):
    """
    This is the base class for all the query classes in astroquery. It
    is implemented as an abstract class and must not be directly instantiated.
    """

    def __init__(self):
        """
        Create the HTTP session and the per-class cache directory.

        The session's User-Agent is prefixed with the astroquery version and
        `_response_hook` is registered as a response hook. The cache
        directory is ``<astropy cache dir>/astroquery/<class name before
        "Class">`` and is created if needed.
        """
        S = self._session = requests.Session()
        self._session.hooks['response'].append(self._response_hook)
        S.headers['User-Agent'] = (
            'astroquery/{vers} {olduseragent}'
            .format(vers=version.version,
                    olduseragent=S.headers['User-Agent']))

        self.cache_location = os.path.join(
            paths.get_cache_dir(), 'astroquery',
            self.__class__.__name__.split("Class")[0])
        os.makedirs(self.cache_location, exist_ok=True)
        self._cache_active = True

    def __call__(self, *args, **kwargs):
        """ init a fresh copy of self """
        return self.__class__(*args, **kwargs)

    def _response_hook(self, response, *args, **kwargs):
        """
        Log each HTTP exchange made through ``self._session``.

        The request is logged at DEBUG (level 10). The response headers and
        body are logged at the more verbose level 5; the body is replaced by
        ``Streaming Data`` when the hook is called with ``stream=True``.

        The log text is only built when the logger's effective level lets it
        through. Building the response log reads ``response.text``, which
        decodes the whole body, so it must not happen at quieter levels.
        """
        loglevel = log.getEffectiveLevel()

        if loglevel <= 10:
            # Log request at DEBUG severity
            request_hdrs = '\n'.join(f'{k}: {v}' for k, v
                                     in response.request.headers.items())
            request_log = textwrap.indent(
                f"-----------------------------------------\n"
                f"{response.request.method} {response.request.url}\n"
                f"{request_hdrs}\n"
                f"\n"
                f"{response.request.body}\n"
                f"-----------------------------------------", '\t')
            log.debug(f"HTTP request\n{request_log}")

        if loglevel <= 5:
            # Log response at super-DEBUG severity
            response_hdrs = '\n'.join(f'{k}: {v}' for k, v
                                      in response.headers.items())
            if kwargs.get('stream'):
                response_log = textwrap.indent(
                    f"-----------------------------------------\n"
                    f"{response.status_code} {response.reason} {response.url}\n"
                    f"{response_hdrs}\n"
                    "Streaming Data\n"
                    f"-----------------------------------------", '\t')
            else:
                response_log = textwrap.indent(
                    f"-----------------------------------------\n"
                    f"{response.status_code} {response.reason} {response.url}\n"
                    f"{response_hdrs}\n"
                    f"\n"
                    f"{response.text}\n"
                    f"-----------------------------------------", '\t')
            log.log(5, f"HTTP response\n{response_log}")

    def _request(self, method, url,
                 params=None, data=None, headers=None,
                 files=None, save=False, savedir='', timeout=None, cache=True,
                 stream=False, auth=None, continuation=True, verify=True,
                 allow_redirects=True,
                 json=None, return_response_on_save=False):
        """
        A generic HTTP request method, similar to `requests.Session.request`
        but with added caching-related tools

        This is a low-level method not generally intended for use by astroquery
        end-users. However, it should _always_ be used by astroquery
        developers; direct uses of `urllib` or `requests` are almost never
        correct.

        Parameters
        ----------
        method : str
            'GET' or 'POST'
        url : str
        params : None or dict
        data : None or dict
        json : None or dict
        headers : None or dict
        auth : None or dict
        files : None or dict
            See `requests.request`
        save : bool
            Whether to save the file to a local directory. Caching will happen
            independent of this parameter if `BaseQuery.cache_location` is set,
            but the save location can be overridden if ``save==True``
        savedir : str
            The location to save the local file if you want to save it
            somewhere other than `BaseQuery.cache_location`
        timeout : int
        cache : bool
        verify : bool
            Verify the server's TLS certificate?
            (see http://docs.python-requests.org/en/master/_modules/requests/sessions/?highlight=verify)
            Honored both for plain requests and for ``save=True`` downloads.
        continuation : bool
            If the file is partly downloaded to the target location, this
            parameter will try to continue the download where it left off.
            See `_download_file`.
        stream : bool
        return_response_on_save : bool
            If ``save``, also return the server response. The default is to only
            return the local file path.

        Returns
        -------
        response : `requests.Response`
            The response from the server if ``save`` is False
        local_filepath : str
            The path of the downloaded local file if ``save`` is True and
            ``return_response_on_save`` is False.
        (local_filepath, response) : tuple(str, `requests.Response` or None)
            The local file path and the server response object, if ``save``
            is True and ``return_response_on_save`` is True. The response is
            None when an existing local file was reused.
        """
        req_kwargs = dict(
            params=params,
            data=data,
            headers=headers,
            files=files,
            timeout=timeout,
            json=json
        )
        if save:
            local_filename = url.split('/')[-1]
            if os.name == 'nt':
                # Windows doesn't allow special characters in filenames like
                # ":" so replace them with an underscore
                local_filename = local_filename.replace(':', '_')
            local_filepath = os.path.join(savedir or self.cache_location or '.',
                                          local_filename)

            response = self._download_file(url, local_filepath, cache=cache,
                                           continuation=continuation,
                                           method=method,
                                           allow_redirects=allow_redirects,
                                           auth=auth, verify=verify,
                                           **req_kwargs)
            if return_response_on_save:
                return local_filepath, response
            else:
                return local_filepath
        else:
            query = AstroQuery(method, url, **req_kwargs)
            if ((self.cache_location is None) or (not self._cache_active)
                    or (not cache)):
                with suspend_cache(self):
                    response = query.request(self._session, stream=stream,
                                             auth=auth, verify=verify,
                                             allow_redirects=allow_redirects,
                                             json=json)
            else:
                response = query.from_cache(self.cache_location)
                if not response:
                    response = query.request(self._session,
                                             self.cache_location,
                                             stream=stream,
                                             auth=auth,
                                             allow_redirects=allow_redirects,
                                             verify=verify,
                                             json=json)
                    to_cache(response, query.request_file(self.cache_location))
            self._last_query = query
            return response

    def _download_file(self, url, local_filepath, timeout=None, auth=None,
                       continuation=True, cache=False, method="GET",
                       head_safe=False, **kwargs):
        """
        Download a file.  Resembles `astropy.utils.data.download_file` but uses
        the local ``_session``

        Parameters
        ----------
        url : string
        local_filepath : string
        timeout : int
        auth : dict or None
        continuation : bool
            If the file has already been partially downloaded *and* the server
            supports HTTP "range" requests, the download will be continued
            where it left off.  If the server ignores the range request and
            sends the whole file, the local file is overwritten instead.
        cache : bool
        method : "GET" or "POST"
        head_safe : bool
            If True, the file size and headers are first probed with a HEAD
            request, and the body is fetched afterwards only if needed.
        **kwargs
            Passed on to `requests.Session.request`.

        Returns
        -------
        response : `requests.Response` or None
            The (closed) response used for the download, or None when an
            existing local file was reused without downloading.
        """
        if head_safe:
            response = self._session.request("HEAD", url,
                                             timeout=timeout, stream=True,
                                             auth=auth, **kwargs)
        else:
            response = self._session.request(method, url,
                                             timeout=timeout, stream=True,
                                             auth=auth, **kwargs)
        # True once ``response`` carries the body to write (not a HEAD reply).
        have_body = not head_safe

        response.raise_for_status()
        if 'content-length' in response.headers:
            length = int(response.headers['content-length'])
            if length == 0:
                log.warning('URL {0} has length=0'.format(url))
        else:
            length = None

        if ((os.path.exists(local_filepath)
             and ('Accept-Ranges' in response.headers)
             and continuation)):
            open_mode = 'ab'

            existing_file_length = os.stat(local_filepath).st_size
            if length is not None and existing_file_length >= length:
                # all done!
                log.info("Found cached file {0} with expected size {1}."
                         .format(local_filepath, existing_file_length))
                response.close()
                return
            elif existing_file_length == 0:
                open_mode = 'wb'
            else:
                if length is not None:
                    log.info("Continuing download of file {0}, with {1} bytes to "
                             "go ({2}%)".format(local_filepath,
                                                length - existing_file_length,
                                                (length-existing_file_length)/length*100))
                else:
                    log.info("Continuing download of file {0} from byte {1}"
                             .format(local_filepath, existing_file_length))

                # bytes are indexed from 0:
                # https://en.wikipedia.org/wiki/List_of_HTTP_header_fields#range-request-header
                end = "{0}".format(length-1) if length is not None else ""
                # Send Range with this request only; setting it on the shared
                # session would leak into later requests if this one raised.
                range_headers = dict(kwargs.get('headers') or {})
                range_headers['Range'] = "bytes={0}-{1}".format(
                    existing_file_length, end)
                range_kwargs = dict(kwargs, headers=range_headers)

                response.close()
                response = self._session.request(method, url,
                                                 timeout=timeout, stream=True,
                                                 auth=auth, **range_kwargs)
                response.raise_for_status()
                have_body = True
                if response.status_code != 206:
                    # The server ignored the range request and is sending the
                    # whole file: overwrite rather than append to the partial.
                    log.info("Server did not honor the range request; "
                             "restarting download of {0}".format(local_filepath))
                    open_mode = 'wb'

        elif cache and os.path.exists(local_filepath):
            if length is not None:
                statinfo = os.stat(local_filepath)
                if statinfo.st_size != length:
                    log.warning(f"Found cached file {local_filepath} with size {statinfo.st_size} "
                                f"that is different from expected size {length}")
                    open_mode = 'wb'
                else:
                    log.info("Found cached file {0} with expected size {1}."
                             .format(local_filepath, statinfo.st_size))
                    response.close()
                    return
            else:
                log.info("Found cached file {0}.".format(local_filepath))
                response.close()
                return
        else:
            open_mode = 'wb'

        if not have_body:
            # Only the HEAD probe has been made so far; fetch the body now.
            response.close()
            response = self._session.request(method, url,
                                             timeout=timeout, stream=True,
                                             auth=auth, **kwargs)
            response.raise_for_status()

        blocksize = astropy.utils.data.conf.download_block_size

        log.debug(f"Downloading URL {url} to {local_filepath} with size {length} "
                  f"by blocks of {blocksize}")

        bytes_read = 0

        # Only show progress bar if logging level is INFO or lower.
        if log.getEffectiveLevel() <= 20:
            progress_stream = None  # Astropy default
        else:
            progress_stream = io.StringIO()

        with ProgressBarOrSpinner(
                length, f'Downloading URL {url} to {local_filepath} ...',
                file=progress_stream) as pb:
            with open(local_filepath, open_mode) as f:
                for block in response.iter_content(blocksize):
                    f.write(block)
                    bytes_read += len(block)
                    if length is not None:
                        pb.update(bytes_read if bytes_read <= length else length)
                    else:
                        pb.update(bytes_read)

        response.close()
        return response
class suspend_cache:
    """
    A context manager that suspends caching.

    Caching is disabled on entry and the previous ``_cache_active`` state is
    restored on exit. Nested uses therefore do not re-enable caching early:
    for example, a ``_request(cache=False)`` made inside a suspended
    ``login`` no longer turns caching back on for the rest of the login.
    """

    def __init__(self, obj):
        self.obj = obj
        self._previous = True  # state to restore on exit

    def __enter__(self):
        self._previous = getattr(self.obj, '_cache_active', True)
        self.obj._cache_active = False

    def __exit__(self, exc_type, exc_value, traceback):
        self.obj._cache_active = self._previous
        return False  # never swallow exceptions
class QueryWithLogin(BaseQuery):
    """
    This is the base class for all the query classes which are required to
    have a login to access the data.

    The abstract method _login() must be implemented. It is wrapped by the
    login() method, which turns off the cache. This way, login credentials
    are not stored in the cache.
    """

    def __init__(self):
        super().__init__()
        self._authenticated = False

    def _get_password(self, service_name, username, reenter=False):
        """
        Get password from keyring or prompt.

        Parameters
        ----------
        service_name : str
            Keyring service name under which the password is looked up.
        username : str
        reenter : bool
            Unless this is exactly ``False``, the keyring lookup is skipped
            and the user is always prompted.

        Returns
        -------
        password : str
            The password to use (from the keyring, or typed by the user).
        password_from_keyring : str or None
            The password found in the keyring, or None if none was found or
            the lookup was skipped.
        """
        password_from_keyring = None
        if reenter is False:
            try:
                password_from_keyring = keyring.get_password(
                    service_name, username)
            except keyring.errors.KeyringError as exc:
                log.warning("Failed to get a valid keyring for password "
                            "storage: {}".format(exc))
            # Only report a missing keychain entry when we actually looked.
            if password_from_keyring is None:
                log.warning("No password was found in the keychain for the "
                            "provided username.")

        if password_from_keyring is None:
            if system_tools.in_ipynb():
                log.warning("You may be using an ipython notebook:"
                            " the password form will appear in your terminal.")
            password = getpass.getpass("{0}, enter your password:\n"
                                       .format(username))
        else:
            password = password_from_keyring

        return password, password_from_keyring

    @abc.abstractmethod
    def _login(self, *args, **kwargs):
        """
        login to non-public data as a known user

        Parameters
        ----------
        Keyword arguments that can be used to create
        the data payload(dict) sent via `requests.post`
        """
        pass

    def login(self, *args, **kwargs):
        """Run ``_login`` with caching suspended and return its result."""
        with suspend_cache(self):
            self._authenticated = self._login(*args, **kwargs)
        return self._authenticated

    def authenticated(self):
        """Return the result of the last ``login`` call (False before any)."""
        return self._authenticated