diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index 2b3b8fa3e..9b4e842ee 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -6,8 +6,8 @@ name: Publish to PyPi on: workflow_dispatch: push: - branches: - - master + tags: + - 'v*.*.*' jobs: build-n-publish: @@ -19,12 +19,13 @@ jobs: fetch-depth: 0 - uses: actions/setup-python@v1 with: - python-version: 3.8 + python-version: 3.7 - name: Build dist files run: | python -m pip install --upgrade pip pip install -e .[test] - python setup.py sdist --formats=gztar + python setup.py sdist --formats=gztar bdist_wheel + git describe --tag --dirty --always - name: Publish distribution 📦 to Test PyPI uses: pypa/gh-action-pypi-publish@release/v1 # license BSD-2 with: diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 60a209b61..b83af5a4b 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -29,7 +29,3 @@ jobs: if: always() run: | pytest test - - - name: Run Mypy tests - run: | - mypy --show-error-codes --disable-error-code misc --disable-error-code import tableauserverclient test diff --git a/MANIFEST.in b/MANIFEST.in index c9bb30ee7..9b7512fb9 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,13 +1,14 @@ -include versioneer.py -include tableauserverclient/_version.py +include CHANGELOG.md +include contributing.md +include CONTRIBUTORS.md include LICENSE include LICENSE.versioneer include README.md -include CHANGELOG.md +include tableauserverclient/_version.py +include versioneer.py recursive-include docs *.md recursive-include samples *.py recursive-include samples *.txt -recursive-include smoke *.py recursive-include test *.csv recursive-include test *.dict recursive-include test *.hyper @@ -16,5 +17,6 @@ recursive-include test *.pdf recursive-include test *.png recursive-include test *.py recursive-include test *.xml +recursive-include test *.tde global-include *.pyi -global-include *.typed \ No newline at end of file +global-include *.typed diff --git a/contributing.md b/contributing.md index c5f0fa95e..90fbdc4f0 100644 --- a/contributing.md +++ b/contributing.md @@ -57,9 +57,8 @@ somewhere. ## Getting Started ```shell -pip install versioneer -python setup.py build -python setup.py test +python -m build +pytest ``` ### To use your locally built version diff --git a/pyproject.toml b/pyproject.toml index 1884a6b37..840c062e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,37 @@ [build-system] -requires = ["setuptools>=45.0", "versioneer-518", "wheel"] +requires = ["setuptools>=45.0", "versioneer>=0.24", "wheel"] build-backend = "setuptools.build_meta" +[project] +name="tableauserverclient" + +dynamic = ["version"] +description='A Python module for working with the Tableau Server REST API.' +authors = [{name="Tableau", email="github@tableau.com"}] +license = {file = "LICENSE"} +readme = "README.md" + +dependencies = [ + 'defusedxml>=0.7.1', + 'packaging~=21.3', + 'requests>=2.28', + 'urllib3~=1.26.8', +] +requires-python = ">=3.7" +classifiers = [ + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10" +] +[project.urls] +repository = "https://github.com/tableau/server-client-python" + +[project.optional-dependencies] +test = ["argparse", "black", "mock", "mypy", "pytest>=7.0", "requests-mock>=1.0,<2.0"] + [tool.black] line-length = 120 target-version = ['py37', 'py38', 'py39', 'py310'] @@ -11,8 +41,10 @@ disable_error_code = [ 'misc', 'import' ] -files = [ - "tableauserverclient", - "test" -] +files = ["tableauserverclient", "test"] show_error_codes = true +ignore_missing_imports = true + +[tool.pytest.ini_options] +testpaths = ["test"] +addopts = "--junitxml=./test.junit.xml" diff --git a/samples/create_group.py b/samples/create_group.py index 3875ffea5..50d84a187 100644 --- a/samples/create_group.py +++ b/samples/create_group.py @@ -8,10 +8,13 @@ import argparse import logging +import os from datetime import time +from typing import List import tableauserverclient as TSC +from tableauserverclient import ServerResponseError def main(): @@ -35,7 +38,7 @@ def main(): ) # Options specific to this sample # This sample has no additional options, yet. If you add some, please add them here - + parser.add_argument("--file", help="csv file containing user info", required=False) args = parser.parse_args() # Set logging level based on user input, or error by default @@ -45,9 +48,48 @@ def main(): tableau_auth = TSC.PersonalAccessTokenAuth(args.token_name, args.token_value, site_id=args.site) server = TSC.Server(args.server, use_server_version=True) with server.auth.sign_in(tableau_auth): + # this code shows 3 different error codes that mean "resource is already in collection" + # 409009: group already exists on server + # 409107: user is already on site + # 409011: user is already in group + group = TSC.GroupItem("test") - group = server.groups.create(group) - print(group) + try: + group = server.groups.create(group) + except TSC.server.endpoint.exceptions.ServerResponseError as rError: + if rError.code == "409009": + print("Group already exists") + group = server.groups.filter(name=group.name)[0] + else: + raise rError + server.groups.populate_users(group) + for user in group.users: + print(user.name) + + if args.file: + filepath = os.path.abspath(args.file) + print("Add users to site from file {}:".format(filepath)) + added: List[TSC.UserItem] + failed: List[TSC.UserItem, TSC.ServerResponseError] + added, failed = server.users.create_from_file(filepath) + for user, error in failed: + print(user, error.code) + if error.code == "409017": + user = server.users.filter(name=user.name)[0] + added.append(user) + print("Adding users to group:{}".format(added)) + for user in added: + print("Adding user {}".format(user)) + try: + server.groups.add_user(group, user.id) + except ServerResponseError as serverError: + if serverError.code == "409011": + print("user {} is already a member of group {}".format(user.name, group.name)) + else: + raise rError + + for user in group.users: + print(user.name) if __name__ == "__main__": diff --git a/samples/explore_site.py b/samples/explore_site.py new file mode 100644 index 000000000..8c4abd9d3 --- /dev/null +++ b/samples/explore_site.py @@ -0,0 +1,83 @@ +#### +# This script demonstrates how to use the Tableau Server Client +# to interact with sites. +#### + +import argparse +import logging +import os.path +import sys + +import tableauserverclient as TSC + + +def main(): + + parser = argparse.ArgumentParser(description="Explore site updates by the Server API.") + # Common options; please keep those in sync across all samples + parser.add_argument("--server", "-s", required=True, help="server address") + parser.add_argument("--site", "-S", help="site name") + parser.add_argument( + "--token-name", "-p", required=True, help="name of the personal access token used to sign into the server" + ) + parser.add_argument( + "--token-value", "-v", required=True, help="value of the personal access token used to sign into the server" + ) + parser.add_argument( + "--logging-level", + "-l", + choices=["debug", "info", "error"], + default="error", + help="desired logging level (set to error by default)", + ) + + parser.add_argument("--delete") + parser.add_argument("--create") + parser.add_argument("--url") + parser.add_argument("--new_site_name") + parser.add_argument("--user_quota") + parser.add_argument("--storage_quota") + parser.add_argument("--status") + + args = parser.parse_args() + + # Set logging level based on user input, or error by default + logging_level = getattr(logging, args.logging_level.upper()) + logging.basicConfig(level=logging_level) + + # SIGN IN + tableau_auth = TSC.PersonalAccessTokenAuth(args.token_name, args.token_value, site_id=args.site) + server = TSC.Server(args.server, use_server_version=True) + new_site = None + with server.auth.sign_in(tableau_auth): + current_site = server.sites.get_by_id(server.site_id) + + if args.delete: + print("You can only delete the site you are currently in") + print("Delete site `{}`?".format(current_site.name)) + # server.sites.delete(server.site_id) + + elif args.create: + new_site = TSC.SiteItem(args.create, args.url or args.create) + site_item = server.sites.create(new_site) + print(site_item) + # to do anything further with the site, you need to log into it + # if a PAT is required, that means going to the UI to create one + + else: + new_site = current_site + print(current_site, "current user quota:", current_site.user_quota) + print("Remember, you can only update the site you are currently in") + if args.url: + new_site.content_url = args.url + if args.user_quota: + new_site.user_quota = args.user_quota + try: + updated_site = server.sites.update(new_site) + print(updated_site, "new user quota:", updated_site.user_quota) + except TSC.ServerResponseError as e: + print(e) + + +if __name__ == "__main__": + main() diff --git a/samples/list.py b/samples/list.py index 814c1b9ca..b5cdb38a5 100644 --- a/samples/list.py +++ b/samples/list.py @@ -59,7 +59,10 @@ def main(): count = 0 for resource in TSC.Pager(endpoint.get, options): count = count + 1 - print(resource.id, resource.name) + # endpoint.populate_connections(resource) + print(resource.name[:18], " ") # , resource._connections()) + if count > 100: + break print("Total: {}".format(count)) diff --git a/samples/online_users.csv b/samples/online_users.csv new file mode 100644 index 000000000..bf4843679 --- /dev/null +++ b/samples/online_users.csv @@ -0,0 +1,2 @@ +ayoung@tableau.com, , , "Creator", None, Yes +ahsiao@tableau.com, , , "Explorer", None, No diff --git a/setup.cfg b/setup.cfg index dafb578b7..a551fdb6a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,27 +1,10 @@ -[wheel] -universal = 1 - -[pep8] -max_line_length = 120 - # See the docstring in versioneer.py for instructions. Note that you must # re-run 'versioneer.py setup' after changing this section, and commit the # resulting files. - +# versioneer does not support pyproject.toml [versioneer] VCS = git style = pep440-pre versionfile_source = tableauserverclient/_version.py versionfile_build = tableauserverclient/_version.py tag_prefix = v -#parentdir_prefix = - -[aliases] -smoke=pytest - -[tool:pytest] -testpaths = test smoke -addopts = --junitxml=./test.junit.xml - -[mypy] -ignore_missing_imports = True diff --git a/setup.py b/setup.py index 24d35250c..60d8fe6b8 100644 --- a/setup.py +++ b/setup.py @@ -1,49 +1,22 @@ -import sys import versioneer +from setuptools import setup -try: - from setuptools import setup -except ImportError: - from distutils.core import setup - -from os import path -this_directory = path.abspath(path.dirname(__file__)) -with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: - long_description = f.read() - -# Only install pytest and runner when test command is run -# This makes work easier for offline installs or low bandwidth machines -needs_pytest = {'pytest', 'test', 'ptr'}.intersection(sys.argv) -pytest_runner = ['pytest-runner'] if needs_pytest else [] -test_requirements = ['black', 'mock', 'pytest', 'requests-mock>=1.0,<2.0', 'mypy>=0.920'] - +""" +once versioneer 0.25 gets released, we can move this from setup.cfg to pyproject.toml +[tool.versioneer] +VCS = "git" +style = "pep440-pre" +versionfile_source = "tableauserverclient/_version.py" +versionfile_build = "tableauserverclient/_version.py" +tag_prefix = "v" +""" setup( - name='tableauserverclient', version=versioneer.get_version(), cmdclass=versioneer.get_cmdclass(), - author='Tableau', - author_email='github@tableau.com', - url='https://github.com/tableau/server-client-python', - package_data={'tableauserverclient':['py.typed']}, + # not yet sure how to move this to pyproject.toml packages=['tableauserverclient', 'tableauserverclient.helpers', 'tableauserverclient.models', 'tableauserverclient.server', 'tableauserverclient.server.endpoint'], - license='MIT', - description='A Python module for working with the Tableau Server REST API.', - long_description=long_description, - long_description_content_type='text/markdown', - test_suite='test', - setup_requires=pytest_runner, - install_requires=[ - 'defusedxml>=0.7.1', - 'requests>=2.11,<3.0', - ], - python_requires='>3.7.0', - tests_require=test_requirements, - extras_require={ - 'test': test_requirements - }, - zip_safe=False ) diff --git a/smoke/__init__.py b/smoke/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tableauserverclient/__init__.py b/tableauserverclient/__init__.py index 592551b4e..394184120 100644 --- a/tableauserverclient/__init__.py +++ b/tableauserverclient/__init__.py @@ -1,62 +1,52 @@ -from ._version import get_versions +from .namespace import NEW_NAMESPACE as DEFAULT_NAMESPACE from .models import ( + BackgroundJobItem, + ColumnItem, ConnectionCredentials, ConnectionItem, + DQWItem, + DailyInterval, DataAlertItem, + DatabaseItem, DatasourceItem, - DQWItem, + FlowItem, + FlowRunItem, GroupItem, + HourlyInterval, + IntervalItem, JobItem, - BackgroundJobItem, + MetricItem, + MonthlyInterval, PaginationItem, + Permission, + PermissionsRule, + PersonalAccessTokenAuth, ProjectItem, + RevisionItem, ScheduleItem, SiteItem, + SubscriptionItem, + TableItem, TableauAuth, - PersonalAccessTokenAuth, + Target, + TaskItem, + UnpopulatedPropertyError, UserItem, ViewItem, - WorkbookItem, - UnpopulatedPropertyError, - HourlyInterval, - DailyInterval, - WeeklyInterval, - MonthlyInterval, - IntervalItem, - TaskItem, - SubscriptionItem, - Target, - PermissionsRule, - Permission, - DatabaseItem, - TableItem, - ColumnItem, - FlowItem, WebhookItem, - PersonalAccessTokenAuth, - FlowRunItem, - RevisionItem, - MetricItem, - TableauItem, - Resource, - plural_type, + WeeklyInterval, + WorkbookItem, ) -from .namespace import NEW_NAMESPACE as DEFAULT_NAMESPACE from .server import ( - RequestOptions, CSVRequestOptions, ImageRequestOptions, PDFRequestOptions, - Filter, - Sort, - Server, - ServerResponseError, + RequestOptions, MissingRequiredFieldError, NotSignedInError, + ServerResponseError, + Filter, Pager, + Server, + Sort, ) -from .helpers import * - -__version__ = get_versions()["version"] -__VERSION__ = __version__ -del get_versions diff --git a/tableauserverclient/datetime_helpers.py b/tableauserverclient/datetime_helpers.py index 2b1df202c..0d968428d 100644 --- a/tableauserverclient/datetime_helpers.py +++ b/tableauserverclient/datetime_helpers.py @@ -1,12 +1,12 @@ import datetime -# This code below is from the python documentation for -# tzinfo: https://docs.python.org/2.3/lib/datetime-tzinfo.html ZERO = datetime.timedelta(0) HOUR = datetime.timedelta(hours=1) +# This class is a concrete implementation of the abstract base class tzinfo +# docs: https://docs.python.org/2.3/lib/datetime-tzinfo.html class UTC(datetime.tzinfo): """UTC""" diff --git a/tableauserverclient/models/connection_item.py b/tableauserverclient/models/connection_item.py index 17ca20bb9..ed7733076 100644 --- a/tableauserverclient/models/connection_item.py +++ b/tableauserverclient/models/connection_item.py @@ -1,44 +1,48 @@ +from typing import TYPE_CHECKING, List, Optional from defusedxml.ElementTree import fromstring from .connection_credentials import ConnectionCredentials +if TYPE_CHECKING: + from tableauserverclient.models.connection_credentials import ConnectionCredentials + class ConnectionItem(object): def __init__(self): - self._datasource_id = None - self._datasource_name = None - self._id = None - self._connection_type = None - self.embed_password = None - self.password = None - self.server_address = None - self.server_port = None - self.username = None - self.connection_credentials = None + self._datasource_id: Optional[str] = None + self._datasource_name: Optional[str] = None + self._id: Optional[str] = None + self._connection_type: Optional[str] = None + self.embed_password: bool = None + self.password: Optional[str] = None + self.server_address: Optional[str] = None + self.server_port: Optional[str] = None + self.username: Optional[str] = None + self.connection_credentials: Optional["ConnectionCredentials"] = None @property - def datasource_id(self): + def datasource_id(self) -> Optional[str]: return self._datasource_id @property - def datasource_name(self): + def datasource_name(self) -> Optional[str]: return self._datasource_name @property - def id(self): + def id(self) -> Optional[str]: return self._id @property - def connection_type(self): + def connection_type(self) -> Optional[str]: return self._connection_type def __repr__(self): - return "".format( + return "".format( **self.__dict__ ) @classmethod - def from_response(cls, resp, ns): + def from_response(cls, resp, ns) -> List["ConnectionItem"]: all_connection_items = list() parsed_response = fromstring(resp) all_connection_xml = parsed_response.findall(".//t:connection", namespaces=ns) @@ -58,7 +62,7 @@ def from_response(cls, resp, ns): return all_connection_items @classmethod - def from_xml_element(cls, parsed_response, ns): + def from_xml_element(cls, parsed_response, ns) -> List["ConnectionItem"]: """ @@ -69,7 +73,7 @@ def from_xml_element(cls, parsed_response, ns): """ - all_connection_items = list() + all_connection_items: List["ConnectionItem"] = list() all_connection_xml = parsed_response.findall(".//t:connection", namespaces=ns) for connection_xml in all_connection_xml: @@ -82,11 +86,13 @@ def from_xml_element(cls, parsed_response, ns): if connection_credentials is not None: - connection_item.connection_credentials = ConnectionCredentials.from_xml_element(connection_credentials) + connection_item.connection_credentials = ConnectionCredentials.from_xml_element( + connection_credentials, ns + ) return all_connection_items # Used to convert string represented boolean to a boolean type -def string_to_bool(s): +def string_to_bool(s: str) -> bool: return s.lower() == "true" diff --git a/tableauserverclient/models/group_item.py b/tableauserverclient/models/group_item.py index 6fcf18544..eb03b1b5d 100644 --- a/tableauserverclient/models/group_item.py +++ b/tableauserverclient/models/group_item.py @@ -27,6 +27,11 @@ def __init__(self, name=None, domain_name=None) -> None: self.name: Optional[str] = name self.domain_name: Optional[str] = domain_name + def __str__(self): + return "{}({!r})".format(self.__class__.__name__, self.__dict__) + + __repr__ = __str__ + @property def domain_name(self) -> Optional[str]: return self._domain_name @@ -74,9 +79,6 @@ def users(self) -> "Pager": # Each call to `.users` should create a new pager, this just runs the callable return self._users() - def to_reference(self) -> ResourceReference: - return ResourceReference(id_=self.id, tag_name=self.tag_name) - def _set_users(self, users: Callable[..., "Pager"]) -> None: self._users = users diff --git a/tableauserverclient/models/project_item.py b/tableauserverclient/models/project_item.py index 9237d134e..acb14ce91 100644 --- a/tableauserverclient/models/project_item.py +++ b/tableauserverclient/models/project_item.py @@ -9,9 +9,6 @@ from typing import List, Optional -from typing import List, Optional, TYPE_CHECKING - - class ProjectItem(object): class ContentPermissions: LockedToProject: str = "LockedToProject" diff --git a/tableauserverclient/models/site_item.py b/tableauserverclient/models/site_item.py index 2d27acabf..3deda03e2 100644 --- a/tableauserverclient/models/site_item.py +++ b/tableauserverclient/models/site_item.py @@ -1,8 +1,8 @@ import warnings import xml.etree.ElementTree as ET +from distutils.version import Version from defusedxml.ElementTree import fromstring - from .property_decorators import ( property_is_enum, property_is_boolean, @@ -14,7 +14,10 @@ VALID_CONTENT_URL_RE = r"^[a-zA-Z0-9_\-]*$" -from typing import List, Optional, Union +from typing import List, Optional, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from tableauserverclient.server import Server class SiteItem(object): @@ -23,6 +26,19 @@ class SiteItem(object): _tier_explorer_capacity: Optional[int] = None _tier_viewer_capacity: Optional[int] = None + def __str__(self): + return ( + "<" + + __name__ + + ": " + + (self.name or "unnamed") + + ", " + + (self.id or "unknown-id") + + ", " + + (self.state or "unknown-state") + + ">" + ) + class AdminMode: ContentAndUsers: str = "ContentAndUsers" ContentOnly: str = "ContentOnly" @@ -261,6 +277,13 @@ def cataloging_enabled(self) -> bool: def cataloging_enabled(self, value: bool): self._cataloging_enabled = value + def is_default(self) -> bool: + return self.name.lower() == "default" + + @staticmethod + def use_new_flow_settings(parent_srv: "Server") -> bool: + return parent_srv is not None and parent_srv.check_at_least_version("3.10") + @property def flows_enabled(self) -> bool: return self._flows_enabled @@ -268,11 +291,10 @@ def flows_enabled(self) -> bool: @flows_enabled.setter @property_is_boolean def flows_enabled(self, value: bool) -> None: + # Flows Enabled' is not a supported site setting in API Version [3.17]. + # In Version 3.10+ use the more granular settings 'Editing Flows Enabled' and/or 'Scheduling Flows Enabled' self._flows_enabled = value - def is_default(self) -> bool: - return self.name.lower() == "default" - @property def editing_flows_enabled(self) -> bool: return self._editing_flows_enabled diff --git a/tableauserverclient/models/user_item.py b/tableauserverclient/models/user_item.py index f60e72951..032841dc7 100644 --- a/tableauserverclient/models/user_item.py +++ b/tableauserverclient/models/user_item.py @@ -1,7 +1,8 @@ -from datetime import datetime +import io +import logging import xml.etree.ElementTree as ET from datetime import datetime -from typing import Dict, List, Optional, TYPE_CHECKING +from enum import IntEnum from defusedxml.ElementTree import fromstring @@ -9,15 +10,11 @@ from .property_decorators import ( property_is_enum, property_not_empty, - property_not_nullable, ) from .reference_item import ResourceReference from ..datetime_helpers import parse_datetime -if TYPE_CHECKING: - from ..server.pager import Pager - -from typing import Dict, List, Optional, TYPE_CHECKING +from typing import Dict, List, Optional, TYPE_CHECKING, Tuple if TYPE_CHECKING: from ..server.pager import Pager @@ -72,6 +69,10 @@ def __init__( return None + def __repr__(self) -> str: + str_site_role = self.site_role or "None" + return "".format(self.id, self.name, str_site_role) + @property def auth_setting(self) -> Optional[str]: return self._auth_setting @@ -106,12 +107,24 @@ def name(self) -> Optional[str]: def name(self, value: str): self._name = value + # valid: username, domain/username, username@domain, domain/username@email + @staticmethod + def validate_username_or_throw(username) -> None: + if username is None or username == "" or username.strip(" ") == "": + raise AttributeError("Username cannot be empty") + if username.find(" ") >= 0: + raise AttributeError("Username cannot contain spaces") + at_symbol = username.find("@") + if at_symbol >= 0: + username = username[:at_symbol] + "X" + username[at_symbol + 1 :] + if username.find("@") >= 0: + raise AttributeError("Username cannot repeat '@'") + @property def site_role(self) -> Optional[str]: return self._site_role @site_role.setter - @property_not_nullable @property_is_enum(Roles) def site_role(self, value): self._site_role = value @@ -137,9 +150,6 @@ def groups(self) -> "Pager": raise UnpopulatedPropertyError(error) return self._groups() - def to_reference(self) -> ResourceReference: - return ResourceReference(id_=self.id, tag_name=self.tag_name) - def _set_workbooks(self, workbooks) -> None: self._workbooks = workbooks @@ -259,5 +269,149 @@ def _parse_element(user_xml, ns): domain_name, ) - def __repr__(self) -> str: - return "".format(self.id, self.name, self.site_role) + class CSVImport(object): + """ + This class includes hardcoded options and logic for the CSV file format defined for user import + https://help.tableau.com/current/server/en-us/users_import.htm + """ + + # username, password, display_name, license, admin_level, publishing, email, auth type + class ColumnType(IntEnum): + USERNAME = 0 + PASS = 1 + DISPLAY_NAME = 2 + LICENSE = 3 # aka site role + ADMIN = 4 + PUBLISHER = 5 + EMAIL = 6 + AUTH = 7 + + MAX = 7 + + # Read a csv line and create a user item populated by the given attributes + @staticmethod + def create_user_from_line(line: str): + if line is None or line is False or line == "\n" or line == "": + return None + line = line.strip().lower() + values: List[str] = list(map(str.strip, line.split(","))) + user = UserItem(values[UserItem.CSVImport.ColumnType.USERNAME]) + if len(values) > 1: + if len(values) > UserItem.CSVImport.ColumnType.MAX: + raise ValueError("Too many attributes for user import") + while len(values) <= UserItem.CSVImport.ColumnType.MAX: + values.append("") + site_role = UserItem.CSVImport._evaluate_site_role( + values[UserItem.CSVImport.ColumnType.LICENSE], + values[UserItem.CSVImport.ColumnType.ADMIN], + values[UserItem.CSVImport.ColumnType.PUBLISHER], + ) + + user._set_values( + None, + values[UserItem.CSVImport.ColumnType.USERNAME], + site_role, + None, + None, + values[UserItem.CSVImport.ColumnType.DISPLAY_NAME], + values[UserItem.CSVImport.ColumnType.EMAIL], + values[UserItem.CSVImport.ColumnType.AUTH], + None, + ) + return user + + # Read through an entire CSV file meant for user import + # Return the number of valid lines and a list of all the invalid lines + @staticmethod + def validate_file_for_import(csv_file: io.TextIOWrapper, logger) -> Tuple[int, List[str]]: + num_valid_lines = 0 + invalid_lines = [] + csv_file.seek(0) # set to start of file in case it has been read earlier + line: str = csv_file.readline() + while line and line != "": + try: + # do not print passwords + logger.info("Reading user {}".format(line[:4])) + UserItem.CSVImport._validate_import_line_or_throw(line, logger) + num_valid_lines += 1 + except Exception as exc: + logger.info("Error parsing {}: {}".format(line[:4], exc)) + invalid_lines.append(line) + line = csv_file.readline() + return num_valid_lines, invalid_lines + + # Some fields in the import file are restricted to specific values + # Iterate through each field and validate the given value against hardcoded constraints + @staticmethod + def _validate_import_line_or_throw(incoming, logger) -> None: + _valid_attributes: List[List[str]] = [ + [], + [], + [], + ["creator", "explorer", "viewer", "unlicensed"], # license + ["system", "site", "none", "no"], # admin + ["yes", "true", "1", "no", "false", "0"], # publisher + [], + [UserItem.Auth.SAML, UserItem.Auth.OpenID, UserItem.Auth.ServerDefault], # auth + ] + + line = list(map(str.strip, incoming.split(","))) + if len(line) > UserItem.CSVImport.ColumnType.MAX: + raise AttributeError("Too many attributes in line") + username = line[UserItem.CSVImport.ColumnType.USERNAME.value] + logger.debug("> details - {}".format(username)) + UserItem.validate_username_or_throw(username) + for i in range(1, len(line)): + logger.debug("column {}: {}".format(UserItem.CSVImport.ColumnType(i).name, line[i])) + UserItem.CSVImport._validate_attribute_value( + line[i], _valid_attributes[i], UserItem.CSVImport.ColumnType(i) + ) + + # Given a restricted set of possible values, confirm the item is in that set + @staticmethod + def _validate_attribute_value(item: str, possible_values: List[str], column_type) -> None: + if item is None or item == "": + # value can be empty for any column except user, which is checked elsewhere + return + if item in possible_values or possible_values == []: + return + raise AttributeError("Invalid value {} for {}".format(item, column_type)) + + # https://help.tableau.com/current/server/en-us/csvguidelines.htm#settings_and_site_roles + # This logic is hardcoded to match the existing rules for import csv files + @staticmethod + def _evaluate_site_role(license_level, admin_level, publisher): + if not license_level or not admin_level or not publisher: + return "Unlicensed" + # ignore case everywhere + license_level = license_level.lower() + admin_level = admin_level.lower() + publisher = publisher.lower() + # don't need to check publisher for system/site admin + if admin_level == "system": + site_role = "SiteAdministrator" + elif admin_level == "site": + if license_level == "creator": + site_role = "SiteAdministratorCreator" + elif license_level == "explorer": + site_role = "SiteAdministratorExplorer" + else: + site_role = "SiteAdministratorExplorer" + else: # if it wasn't 'system' or 'site' then we can treat it as 'none' + if publisher == "yes": + if license_level == "creator": + site_role = "Creator" + elif license_level == "explorer": + site_role = "ExplorerCanPublish" + else: + site_role = "Unlicensed" # is this the expected outcome? + else: # publisher == 'no': + if license_level == "explorer" or license_level == "creator": + site_role = "Explorer" + elif license_level == "viewer": + site_role = "Viewer" + else: # if license_level == 'unlicensed' + site_role = "Unlicensed" + if site_role is None: + site_role = "Unlicensed" + return site_role diff --git a/tableauserverclient/server/__init__.py b/tableauserverclient/server/__init__.py index cb680d914..25abb3c9a 100644 --- a/tableauserverclient/server/__init__.py +++ b/tableauserverclient/server/__init__.py @@ -9,7 +9,7 @@ from .filter import Filter from .sort import Sort -from .. import ( +from ..models import ( BackgroundJobItem, ColumnItem, ConnectionItem, diff --git a/tableauserverclient/server/endpoint/auth_endpoint.py b/tableauserverclient/server/endpoint/auth_endpoint.py index 11e89975a..6baf399ed 100644 --- a/tableauserverclient/server/endpoint/auth_endpoint.py +++ b/tableauserverclient/server/endpoint/auth_endpoint.py @@ -30,7 +30,7 @@ def sign_in(self, auth_req): signin_req = RequestFactory.Auth.signin_req(auth_req) server_response = self.parent_srv.session.post(url, data=signin_req, **self.parent_srv.http_options) self.parent_srv._namespace.detect(server_response.content) - self._check_status(server_response) + self._check_status(server_response, url) parsed_response = fromstring(server_response.content) site_id = parsed_response.find(".//t:site", namespaces=self.parent_srv.namespace).get("id", None) user_id = parsed_response.find(".//t:user", namespaces=self.parent_srv.namespace).get("id", None) @@ -66,7 +66,7 @@ def switch_site(self, site_item): else: raise e self.parent_srv._namespace.detect(server_response.content) - self._check_status(server_response) + self._check_status(server_response, url) parsed_response = fromstring(server_response.content) site_id = parsed_response.find(".//t:site", namespaces=self.parent_srv.namespace).get("id", None) user_id = parsed_response.find(".//t:user", namespaces=self.parent_srv.namespace).get("id", None) diff --git a/tableauserverclient/server/endpoint/endpoint.py b/tableauserverclient/server/endpoint/endpoint.py index 0acc978d2..378c84746 100644 --- a/tableauserverclient/server/endpoint/endpoint.py +++ b/tableauserverclient/server/endpoint/endpoint.py @@ -11,6 +11,7 @@ NonXMLResponseError, EndpointUnavailableError, ) +from .. import endpoint from ..query import QuerySet from ... import helpers @@ -26,18 +27,29 @@ from requests import Response +_version_header: Optional[str] = None + + class Endpoint(object): def __init__(self, parent_srv: "Server"): + global _version_header self.parent_srv = parent_srv @staticmethod def _make_common_headers(auth_token, content_type): + global _version_header + + if not _version_header: + from ..server import __TSC_VERSION__ + + _version_header = __TSC_VERSION__ + headers = {} if auth_token is not None: headers["x-tableau-auth"] = auth_token if content_type is not None: headers["content-type"] = content_type - + headers["User-Agent"] = "Tableau Server Client/{}".format(_version_header) return headers def _make_request( @@ -63,7 +75,7 @@ def _make_request( logger.debug("request content: {}".format(helpers.strings.redact_xml(content[:1000]))) server_response = method(url, **parameters) - self._check_status(server_response) + self._check_status(server_response, url) loggable_response = self.log_response_safely(server_response) logger.debug("Server response from {0}:\n\t{1}".format(url, loggable_response)) @@ -73,13 +85,13 @@ def _make_request( return server_response - def _check_status(self, server_response): + def _check_status(self, server_response, url: str = None): if server_response.status_code >= 500: - raise InternalServerError(server_response) + raise InternalServerError(server_response, url) elif server_response.status_code not in Success_codes: # todo: is an error reliably of content-type application/xml? try: - raise ServerResponseError.from_response(server_response.content, self.parent_srv.namespace) + raise ServerResponseError.from_response(server_response.content, self.parent_srv.namespace, url) except ParseError: # This will happen if we get a non-success HTTP code that # doesn't return an xml error object (like metadata endpoints or 503 pages) @@ -112,7 +124,7 @@ def get_request(self, url, request_object=None, parameters=None): if request_object is not None: try: # Query param delimiters don't need to be encoded for versions before 3.7 (2020.1) - self.parent_srv.assert_at_least_version("3.7") + self.parent_srv.assert_at_least_version("3.7", "Query param encoding") parameters = parameters or {} parameters["params"] = request_object.get_query_params() except EndpointUnavailableError: @@ -126,7 +138,7 @@ def get_request(self, url, request_object=None, parameters=None): ) def delete_request(self, url): - # We don't return anything for a delete + # We don't return anything for a delete request self._make_request(self.parent_srv.session.delete, url, auth_token=self.parent_srv.auth_token) def put_request(self, url, xml_request=None, content_type=XML_CONTENT_TYPE, parameters=None): @@ -182,7 +194,7 @@ def api(version): def _decorator(func): @wraps(func) def wrapper(self, *args, **kwargs): - self.parent_srv.assert_at_least_version(version) + self.parent_srv.assert_at_least_version(version, "endpoint") return func(self, *args, **kwargs) return wrapper diff --git a/tableauserverclient/server/endpoint/exceptions.py b/tableauserverclient/server/endpoint/exceptions.py index 34de00dd0..3ce0d5e92 100644 --- a/tableauserverclient/server/endpoint/exceptions.py +++ b/tableauserverclient/server/endpoint/exceptions.py @@ -1,62 +1,72 @@ from defusedxml.ElementTree import fromstring -class ServerResponseError(Exception): - def __init__(self, code, summary, detail): +class TableauError(Exception): + pass + + +class ServerResponseError(TableauError): + def __init__(self, code, summary, detail, url=None): self.code = code self.summary = summary self.detail = detail + self.url = url super(ServerResponseError, self).__init__(str(self)) def __str__(self): return "\n\n\t{0}: {1}\n\t\t{2}".format(self.code, self.summary, self.detail) @classmethod - def from_response(cls, resp, ns): + def from_response(cls, resp, ns, url=None): # Check elements exist before .text parsed_response = fromstring(resp) - error_response = cls( - parsed_response.find("t:error", namespaces=ns).get("code", ""), - parsed_response.find(".//t:summary", namespaces=ns).text, - parsed_response.find(".//t:detail", namespaces=ns).text, - ) + try: + error_response = cls( + parsed_response.find("t:error", namespaces=ns).get("code", ""), + parsed_response.find(".//t:summary", namespaces=ns).text, + parsed_response.find(".//t:detail", namespaces=ns).text, + url, + ) + except Exception as e: + raise NonXMLResponseError(resp) return error_response -class InternalServerError(Exception): - def __init__(self, server_response): +class InternalServerError(TableauError): + def __init__(self, server_response, request_url: str = None): self.code = server_response.status_code self.content = server_response.content + self.url = request_url or "server" def __str__(self): - return "\n\nError status code: {0}\n{1}".format(self.code, self.content) + return "\n\nInternal error {0} at {1}\n{2}".format(self.code, self.url, self.content) -class MissingRequiredFieldError(Exception): +class MissingRequiredFieldError(TableauError): pass -class ServerInfoEndpointNotFoundError(Exception): +class ServerInfoEndpointNotFoundError(TableauError): pass -class EndpointUnavailableError(Exception): +class EndpointUnavailableError(TableauError): pass -class ItemTypeNotAllowed(Exception): +class ItemTypeNotAllowed(TableauError): pass -class NonXMLResponseError(Exception): +class NonXMLResponseError(TableauError): pass -class InvalidGraphQLQuery(Exception): +class InvalidGraphQLQuery(TableauError): pass -class GraphQLError(Exception): +class GraphQLError(TableauError): def __init__(self, error_payload): self.error = error_payload @@ -66,7 +76,7 @@ def __str__(self): return pformat(self.error) -class JobFailedException(Exception): +class JobFailedException(TableauError): def __init__(self, job): self.notes = job.notes self.job = job @@ -79,7 +89,7 @@ class JobCancelledException(JobFailedException): pass -class FlowRunFailedException(Exception): +class FlowRunFailedException(TableauError): def __init__(self, flow_run): self.background_job_id = flow_run.background_job_id self.flow_run = flow_run diff --git a/tableauserverclient/server/endpoint/jobs_endpoint.py b/tableauserverclient/server/endpoint/jobs_endpoint.py index 99870ac34..6b709efad 100644 --- a/tableauserverclient/server/endpoint/jobs_endpoint.py +++ b/tableauserverclient/server/endpoint/jobs_endpoint.py @@ -29,7 +29,7 @@ def get( if isinstance(job_id, RequestOptionsBase): req_options = job_id - self.parent_srv.assert_at_least_version("3.1") + self.parent_srv.assert_at_least_version("3.1", "Jobs.get_by_id(job_id)") server_response = self.get_request(self.baseurl, req_options) pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace) jobs = BackgroundJobItem.from_response(server_response.content, self.parent_srv.namespace) diff --git a/tableauserverclient/server/endpoint/server_info_endpoint.py b/tableauserverclient/server/endpoint/server_info_endpoint.py index 5c9461d1c..2036d8d5e 100644 --- a/tableauserverclient/server/endpoint/server_info_endpoint.py +++ b/tableauserverclient/server/endpoint/server_info_endpoint.py @@ -26,6 +26,7 @@ def get(self): raise ServerInfoEndpointNotFoundError if e.code == "404001": raise EndpointUnavailableError + raise e server_info = ServerInfoItem.from_response(server_response.content, self.parent_srv.namespace) return server_info diff --git a/tableauserverclient/server/endpoint/sites_endpoint.py b/tableauserverclient/server/endpoint/sites_endpoint.py index bdf281fb9..67d7db209 100644 --- a/tableauserverclient/server/endpoint/sites_endpoint.py +++ b/tableauserverclient/server/endpoint/sites_endpoint.py @@ -22,6 +22,7 @@ def baseurl(self) -> str: @api(version="2.0") def get(self, req_options: Optional["RequestOptions"] = None) -> Tuple[List[SiteItem], PaginationItem]: logger.info("Querying all sites on site") + logger.info("Requires Server Admin permissions") url = self.baseurl server_response = self.get_request(url, req_options) pagination_item = PaginationItem.from_response(server_response.content, self.parent_srv.namespace) @@ -34,6 +35,10 @@ def get_by_id(self, site_id: str) -> SiteItem: if not site_id: error = "Site ID undefined." raise ValueError(error) + if not site_id == self.parent_srv.site_id: + error = "You can only retrieve the site for which you are currently authenticated." + raise ValueError(error) + logger.info("Querying single site (ID: {0})".format(site_id)) url = "{0}/{1}".format(self.baseurl, site_id) server_response = self.get_request(url) @@ -45,8 +50,10 @@ def get_by_name(self, site_name: str) -> SiteItem: if not site_name: error = "Site Name undefined." raise ValueError(error) + print("Note: You can only work with the site for which you are currently authenticated") logger.info("Querying single site (Name: {0})".format(site_name)) url = "{0}/{1}?key=name".format(self.baseurl, site_name) + print(self.baseurl, url) server_response = self.get_request(url) return SiteItem.from_response(server_response.content, self.parent_srv.namespace)[0] @@ -56,7 +63,12 @@ def get_by_content_url(self, content_url: str) -> SiteItem: if content_url is None: error = "Content URL undefined." raise ValueError(error) + if not self.parent_srv.baseurl.index(content_url) > 0: + error = "You can only work with the site you are currently authenticated for" + raise ValueError(error) + logger.info("Querying single site (Content URL: {0})".format(content_url)) + logger.debug("Querying other sites requires Server Admin permissions") url = "{0}/{1}?key=contentUrl".format(self.baseurl, content_url) server_response = self.get_request(url) return SiteItem.from_response(server_response.content, self.parent_srv.namespace)[0] @@ -67,13 +79,18 @@ def update(self, site_item: SiteItem) -> SiteItem: if not site_item.id: error = "Site item missing ID." raise MissingRequiredFieldError(error) + print(self.parent_srv.site_id, site_item.id) + if not site_item.id == self.parent_srv.site_id: + error = "You can only update the site you are currently authenticated for" + raise ValueError(error) + if site_item.admin_mode: if site_item.admin_mode == SiteItem.AdminMode.ContentOnly and site_item.user_quota: error = "You cannot set admin_mode to ContentOnly and also set a user quota" raise ValueError(error) url = "{0}/{1}".format(self.baseurl, site_item.id) - update_req = RequestFactory.Site.update_req(site_item) + update_req = RequestFactory.Site.update_req(site_item, self.parent_srv) server_response = self.put_request(url, update_req) logger.info("Updated site item (ID: {0})".format(site_item.id)) update_site = copy.copy(site_item) @@ -86,12 +103,11 @@ def delete(self, site_id: str) -> None: error = "Site ID undefined." raise ValueError(error) url = "{0}/{1}".format(self.baseurl, site_id) + if not site_id == self.parent_srv.site_id: + error = "You can only delete the site you are currently authenticated for" + raise ValueError(error) self.delete_request(url) - # If we deleted the site we are logged into - # then we are automatically logged out - if site_id == self.parent_srv.site_id: - logger.info("Deleting current site and clearing auth tokens") - self.parent_srv._clear_auth() + self.parent_srv._clear_auth() logger.info("Deleted single site (ID: {0}) and signed out".format(site_id)) # Create new site @@ -103,7 +119,7 @@ def create(self, site_item: SiteItem) -> SiteItem: raise ValueError(error) url = self.baseurl - create_req = RequestFactory.Site.create_req(site_item) + create_req = RequestFactory.Site.create_req(site_item, self.parent_srv) server_response = self.post_request(url, create_req) new_site = SiteItem.from_response(server_response.content, self.parent_srv.namespace)[0] logger.info("Created new site (ID: {0})".format(new_site.id)) diff --git a/tableauserverclient/server/endpoint/tasks_endpoint.py b/tableauserverclient/server/endpoint/tasks_endpoint.py index f147c79ae..a70480b91 100644 --- a/tableauserverclient/server/endpoint/tasks_endpoint.py +++ b/tableauserverclient/server/endpoint/tasks_endpoint.py @@ -25,7 +25,7 @@ def __normalize_task_type(self, task_type): @api(version="2.6") def get(self, req_options=None, task_type=TaskItem.Type.ExtractRefresh): if task_type == TaskItem.Type.DataAcceleration: - self.parent_srv.assert_at_least_version("3.8") + self.parent_srv.assert_at_least_version("3.8", "Data Acceleration Tasks") logger.info("Querying all {} tasks for the site".format(task_type)) @@ -69,7 +69,7 @@ def run(self, task_item): @api(version="3.6") def delete(self, task_id, task_type=TaskItem.Type.ExtractRefresh): if task_type == TaskItem.Type.DataAcceleration: - self.parent_srv.assert_at_least_version("3.8") + self.parent_srv.assert_at_least_version("3.8", "Data Acceleration Tasks") if not task_id: error = "No Task ID provided" diff --git a/tableauserverclient/server/endpoint/users_endpoint.py b/tableauserverclient/server/endpoint/users_endpoint.py index 738364cd7..28406ab71 100644 --- a/tableauserverclient/server/endpoint/users_endpoint.py +++ b/tableauserverclient/server/endpoint/users_endpoint.py @@ -1,19 +1,16 @@ import copy import logging -from typing import List, Optional, Tuple +import os +from typing import List, Optional, Tuple, Union from .endpoint import QuerysetEndpoint, api -from .exceptions import MissingRequiredFieldError -from .. import ( - RequestFactory, - RequestOptions, - UserItem, - WorkbookItem, - PaginationItem, - GroupItem, -) +from .exceptions import MissingRequiredFieldError, ServerResponseError +from .. import RequestFactory, RequestOptions, UserItem, WorkbookItem, PaginationItem, GroupItem from ..pager import Pager +# duplicate defined in workbooks_endpoint +FilePath = Union[str, os.PathLike] + logger = logging.getLogger("tableau.endpoint.users") @@ -78,12 +75,51 @@ def remove(self, user_id: str, map_assets_to: Optional[str] = None) -> None: @api(version="2.0") def add(self, user_item: UserItem) -> UserItem: url = self.baseurl + logger.info("Add user {}".format(user_item.name)) add_req = RequestFactory.User.add_req(user_item) server_response = self.post_request(url, add_req) + logger.info(server_response) new_user = UserItem.from_response(server_response.content, self.parent_srv.namespace).pop() logger.info("Added new user (ID: {0})".format(new_user.id)) return new_user + # Add new users to site. This does not actually perform a bulk action, it's syntactic sugar + @api(version="2.0") + def add_all(self, users: List[UserItem]): + created = [] + failed = [] + for user in users: + try: + result = self.add(user) + created.append(result) + except Exception as e: + failed.append(user) + return created, failed + + # helping the user by parsing a file they could have used to add users through the UI + # line format: Username [required], password, display name, license, admin, publish + @api(version="2.0") + def create_from_file(self, filepath: str) -> Tuple[List[UserItem], List[Tuple[UserItem, ServerResponseError]]]: + created = [] + failed = [] + if not filepath.find("csv"): + raise ValueError("Only csv files are accepted") + + with open(filepath) as csv_file: + csv_file.seek(0) # set to start of file in case it has been read earlier + line: str = csv_file.readline() + while line and line != "": + user: UserItem = UserItem.CSVImport.create_user_from_line(line) + try: + print(user) + result = self.add(user) + created.append(result) + except ServerResponseError as serverError: + print("failed") + failed.append((user, serverError)) + line = csv_file.readline() + return created, failed + # Get workbooks for user @api(version="2.0") def populate_workbooks(self, user_item: UserItem, req_options: RequestOptions = None) -> None: diff --git a/tableauserverclient/server/endpoint/views_endpoint.py b/tableauserverclient/server/endpoint/views_endpoint.py index 67e66a81f..06cc08349 100644 --- a/tableauserverclient/server/endpoint/views_endpoint.py +++ b/tableauserverclient/server/endpoint/views_endpoint.py @@ -12,7 +12,13 @@ from typing import Iterator, List, Optional, Tuple, TYPE_CHECKING if TYPE_CHECKING: - from ..request_options import RequestOptions, CSVRequestOptions, PDFRequestOptions, ImageRequestOptions + from ..request_options import ( + RequestOptions, + CSVRequestOptions, + PDFRequestOptions, + ImageRequestOptions, + ExcelRequestOptions, + ) class Views(QuerysetEndpoint): @@ -126,7 +132,7 @@ def _get_view_csv(self, view_item: ViewItem, req_options: Optional["CSVRequestOp yield from server_response.iter_content(1024) @api(version="3.8") - def populate_excel(self, view_item: ViewItem, req_options: Optional["CSVRequestOptions"] = None) -> None: + def populate_excel(self, view_item: ViewItem, req_options: Optional["ExcelRequestOptions"] = None) -> None: if not view_item.id: error = "View item missing ID." raise MissingRequiredFieldError(error) @@ -137,7 +143,7 @@ def excel_fetcher(): view_item._set_excel(excel_fetcher) logger.info("Populated excel for view (ID: {0})".format(view_item.id)) - def _get_view_excel(self, view_item: ViewItem, req_options: Optional["CSVRequestOptions"]) -> Iterator[bytes]: + def _get_view_excel(self, view_item: ViewItem, req_options: Optional["ExcelRequestOptions"]) -> Iterator[bytes]: url = "{0}/{1}/crosstab/excel".format(self.baseurl, view_item.id) with closing(self.get_request(url, request_object=req_options, parameters={"stream": True})) as server_response: diff --git a/tableauserverclient/server/query.py b/tableauserverclient/server/query.py index 729447822..c5613b2d6 100644 --- a/tableauserverclient/server/query.py +++ b/tableauserverclient/server/query.py @@ -1,13 +1,21 @@ +from typing import Tuple from .filter import Filter from .request_options import RequestOptions from .sort import Sort import math -def to_camel_case(word): +def to_camel_case(word: str) -> str: return word.split("_")[0] + "".join(x.capitalize() or "_" for x in word.split("_")[1:]) +""" +This interface allows more fluent queries against Tableau Server +e.g server.users.get(name="user@domain.com") +see pagination_sample +""" + + class QuerySet: def __init__(self, model): self.model = model @@ -85,18 +93,21 @@ def _fetch_all(self): if self._result_cache is None: self._result_cache, self._pagination_item = self.model.get(self.request_options) + def __len__(self) -> int: + return self.total_available + @property - def total_available(self): + def total_available(self) -> int: self._fetch_all() return self._pagination_item.total_available @property - def page_number(self): + def page_number(self) -> int: self._fetch_all() return self._pagination_item.page_number @property - def page_size(self): + def page_size(self) -> int: self._fetch_all() return self._pagination_item.page_size @@ -121,7 +132,7 @@ def paginate(self, **kwargs): self.request_options.pagesize = kwargs["page_size"] return self - def _parse_shorthand_filter(self, key): + def _parse_shorthand_filter(self, key: str) -> Tuple[str, str]: tokens = key.split("__", 1) if len(tokens) == 1: operator = RequestOptions.Operator.Equals @@ -135,7 +146,7 @@ def _parse_shorthand_filter(self, key): raise ValueError("Field name `{}` is not valid.".format(field)) return (field, operator) - def _parse_shorthand_sort(self, key): + def _parse_shorthand_sort(self, key: str) -> Tuple[str, str]: direction = RequestOptions.Direction.Asc if key.startswith("-"): direction = RequestOptions.Direction.Desc diff --git a/tableauserverclient/server/request_factory.py b/tableauserverclient/server/request_factory.py index fc00ca085..aad8ca074 100644 --- a/tableauserverclient/server/request_factory.py +++ b/tableauserverclient/server/request_factory.py @@ -1,6 +1,6 @@ from os import name import xml.etree.ElementTree as ET -from typing import Any, Dict, List, Optional, Tuple, Iterable +from typing import Any, Dict, Iterable, List, Optional, Tuple, TYPE_CHECKING from requests.packages.urllib3.fields import RequestField from requests.packages.urllib3.filepost import encode_multipart_formdata @@ -16,8 +16,6 @@ from ..models import TaskItem, UserItem, GroupItem, PermissionsRule, FavoriteItem from ..models import WebhookItem -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Iterable - if TYPE_CHECKING: from ..models import SubscriptionItem from ..models import DataAlertItem @@ -25,6 +23,7 @@ from ..models import ConnectionItem from ..models import SiteItem from ..models import ProjectItem + from tableauserverclient.server import Server def _add_multipart(parts: Dict) -> Tuple[Any, str]: @@ -39,7 +38,7 @@ def _add_multipart(parts: Dict) -> Tuple[Any, str]: def _tsrequest_wrapped(func): - def wrapper(self, *args, **kwargs): + def wrapper(self, *args, **kwargs) -> bytes: xml_request = ET.Element("tsRequest") func(self, xml_request, *args, **kwargs) return ET.tostring(xml_request) @@ -556,7 +555,7 @@ def _add_to_req(self, id_: Optional[str], target_type: str, task_type: str = Tas """ if not isinstance(id_, str): - raise ValueError(f"id_ should be a string, reeceived: {type(id_)}") + raise ValueError(f"id_ should be a string, received: {type(id_)}") xml_request = ET.Element("tsRequest") task_element = ET.SubElement(xml_request, "task") task = ET.SubElement(task_element, task_type) @@ -576,7 +575,7 @@ def add_flow_req(self, id_: Optional[str], task_type: str = TaskItem.Type.RunFlo class SiteRequest(object): - def update_req(self, site_item: "SiteItem"): + def update_req(self, site_item: "SiteItem", parent_srv: "Server" = None): xml_request = ET.Element("tsRequest") site_element = ET.SubElement(xml_request, "site") if site_item.name: @@ -601,14 +600,15 @@ def update_req(self, site_item: "SiteItem"): site_element.attrib["revisionHistoryEnabled"] = str(site_item.revision_history_enabled).lower() if site_item.data_acceleration_mode is not None: site_element.attrib["dataAccelerationMode"] = str(site_item.data_acceleration_mode).lower() - if site_item.flows_enabled is not None: - site_element.attrib["flowsEnabled"] = str(site_item.flows_enabled).lower() if site_item.cataloging_enabled is not None: site_element.attrib["catalogingEnabled"] = str(site_item.cataloging_enabled).lower() - if site_item.editing_flows_enabled is not None: - site_element.attrib["editingFlowsEnabled"] = str(site_item.editing_flows_enabled).lower() - if site_item.scheduling_flows_enabled is not None: - site_element.attrib["schedulingFlowsEnabled"] = str(site_item.scheduling_flows_enabled).lower() + + flows_edit = str(site_item.editing_flows_enabled).lower() + flows_schedule = str(site_item.scheduling_flows_enabled).lower() + flows_all = str(site_item.flows_enabled).lower() + + self.set_versioned_flow_attributes(flows_all, flows_edit, flows_schedule, parent_srv, site_element, site_item) + if site_item.allow_subscription_attachments is not None: site_element.attrib["allowSubscriptionAttachments"] = str(site_item.allow_subscription_attachments).lower() if site_item.guest_access_enabled is not None: @@ -682,7 +682,8 @@ def update_req(self, site_item: "SiteItem"): return ET.tostring(xml_request) - def create_req(self, site_item: "SiteItem"): + # server: the site request model changes based on api version + def create_req(self, site_item: "SiteItem", parent_srv: "Server" = None): xml_request = ET.Element("tsRequest") site_element = ET.SubElement(xml_request, "site") site_element.attrib["name"] = site_item.name @@ -701,12 +702,13 @@ def create_req(self, site_item: "SiteItem"): site_element.attrib["revisionLimit"] = str(site_item.revision_limit) if site_item.data_acceleration_mode is not None: site_element.attrib["dataAccelerationMode"] = str(site_item.data_acceleration_mode).lower() - if site_item.flows_enabled is not None: - site_element.attrib["flowsEnabled"] = str(site_item.flows_enabled).lower() - if site_item.editing_flows_enabled is not None: - site_element.attrib["editingFlowsEnabled"] = str(site_item.editing_flows_enabled).lower() - if site_item.scheduling_flows_enabled is not None: - site_element.attrib["schedulingFlowsEnabled"] = str(site_item.scheduling_flows_enabled).lower() + + flows_edit = str(site_item.editing_flows_enabled).lower() + flows_schedule = str(site_item.scheduling_flows_enabled).lower() + flows_all = str(site_item.flows_enabled).lower() + + self.set_versioned_flow_attributes(flows_all, flows_edit, flows_schedule, parent_srv, site_element, site_item) + if site_item.allow_subscription_attachments is not None: site_element.attrib["allowSubscriptionAttachments"] = str(site_item.allow_subscription_attachments).lower() if site_item.guest_access_enabled is not None: @@ -784,6 +786,32 @@ def create_req(self, site_item: "SiteItem"): return ET.tostring(xml_request) + def set_versioned_flow_attributes(self, flows_all, flows_edit, flows_schedule, parent_srv, site_element, site_item): + if (not parent_srv) or SiteItem.use_new_flow_settings(parent_srv): + if site_item.flows_enabled is not None: + flows_edit = flows_edit or flows_all + flows_schedule = flows_schedule or flows_all + import warnings + + warnings.warn( + "FlowsEnabled has been removed and become two options:" + " SchedulingFlowsEnabled and EditingFlowsEnabled" + ) + if site_item.editing_flows_enabled is not None: + site_element.attrib["editingFlowsEnabled"] = flows_edit + if site_item.scheduling_flows_enabled is not None: + site_element.attrib["schedulingFlowsEnabled"] = flows_schedule + + else: + if site_item.flows_enabled is not None: + site_element.attrib["flowsEnabled"] = str(site_item.flows_enabled).lower() + if site_item.editing_flows_enabled is not None or site_item.scheduling_flows_enabled is not None: + flows_all = flows_all or flows_edit or flows_schedule + site_element.attrib["flowsEnabled"] = flows_all + import warnings + + warnings.warn("In version 3.10 and earlier there is only one option: FlowsEnabled") + class TableRequest(object): def update_req(self, table_item): @@ -971,15 +999,15 @@ def embedded_extract_req(self, xml_request, include_all=True, datasources=None): class Connection(object): @_tsrequest_wrapped - def update_req(self, xml_request, connection_item): + def update_req(self, xml_request: ET.Element, connection_item: "ConnectionItem") -> None: connection_element = ET.SubElement(xml_request, "connection") - if connection_item.server_address: + if connection_item.server_address is not None: connection_element.attrib["serverAddress"] = connection_item.server_address.lower() - if connection_item.server_port: + if connection_item.server_port is not None: connection_element.attrib["serverPort"] = str(connection_item.server_port) - if connection_item.username: + if connection_item.username is not None: connection_element.attrib["userName"] = connection_item.username - if connection_item.password: + if connection_item.password is not None: connection_element.attrib["password"] = connection_item.password if connection_item.embed_password is not None: connection_element.attrib["embedPassword"] = str(connection_item.embed_password).lower() diff --git a/tableauserverclient/server/request_options.py b/tableauserverclient/server/request_options.py index 4462ba786..f4ed8fd3c 100644 --- a/tableauserverclient/server/request_options.py +++ b/tableauserverclient/server/request_options.py @@ -1,4 +1,7 @@ from ..models.property_decorators import property_is_int +import logging + +logger = logging.getLogger("tableau.request_options") class RequestOptionsBase(object): @@ -8,6 +11,8 @@ def apply_query_params(self, url): params = self.get_query_params() params_list = ["{}={}".format(k, v) for (k, v) in params.items()] + logger.debug("Applying options to request: <%s(%s)>", self.__class__.__name__, ",".join(params_list)) + if "?" in url: url, existing_params = url.split("?") params_list.append(existing_params) @@ -142,6 +147,28 @@ def get_query_params(self): return params +class ExcelRequestOptions(RequestOptionsBase): + def __init__(self, maxage: int = -1) -> None: + super().__init__() + self.max_age = maxage + + @property + def max_age(self) -> int: + return self._max_age + + @max_age.setter + @property_is_int(range=(0, 240), allowed=[-1]) + def max_age(self, value: int) -> None: + self._max_age = value + + def get_query_params(self): + params = {} + if self.max_age != -1: + params["maxAge"] = self.max_age + + return params + + class ImageRequestOptions(_FilterOptionsBase): # if 'high' isn't specified, the REST API endpoint returns an image with standard resolution class Resolution: diff --git a/tableauserverclient/server/server.py b/tableauserverclient/server/server.py index e35514474..c82f4a6e2 100644 --- a/tableauserverclient/server/server.py +++ b/tableauserverclient/server/server.py @@ -1,8 +1,8 @@ -import urllib3 import requests +import urllib3 + from defusedxml.ElementTree import fromstring from packaging.version import Version - from .endpoint import ( Sites, Views, @@ -30,15 +30,17 @@ Metrics, ) from .endpoint.exceptions import ( - EndpointUnavailableError, ServerInfoEndpointNotFoundError, + EndpointUnavailableError, ) from .exceptions import NotSignedInError from ..namespace import Namespace -import requests -from packaging.version import Version +from .._version import get_versions + +__TSC_VERSION__ = get_versions()["version"] +del get_versions _PRODUCT_TO_REST_VERSION = { "10.0": "2.3", @@ -47,6 +49,9 @@ "9.1": "2.0", "9.0": "2.0", } +minimum_supported_server_version = "2.3" +default_server_version = "2.3" +client_version_header = "X-TableauServerClient-Version" class Server(object): @@ -55,7 +60,7 @@ class PublishMode: Overwrite = "Overwrite" CreateNew = "CreateNew" - def __init__(self, server_address, use_server_version=True, http_options=None): + def __init__(self, server_address, use_server_version=False, http_options=None): self._server_address = server_address self._auth_token = None self._site_id = None @@ -63,7 +68,7 @@ def __init__(self, server_address, use_server_version=True, http_options=None): self._session = requests.Session() self._http_options = dict() - self.version = "2.3" + self.version = default_server_version self.auth = Auth(self) self.views = Views(self) self.users = Users(self) @@ -90,8 +95,10 @@ def __init__(self, server_address, use_server_version=True, http_options=None): self.flow_runs = FlowRuns(self) self.metrics = Metrics(self) + # must set this before calling use_server_version, because that's a server call if http_options: self.add_http_options(http_options) + self.add_http_version_header() if use_server_version: self.use_server_version() @@ -101,8 +108,13 @@ def add_http_options(self, options_dict): if options_dict.get("verify") == False: urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + def add_http_version_header(self): + if not self._http_options[client_version_header]: + self._http_options.update({client_version_header: __TSC_VERSION__}) + def clear_http_options(self): self._http_options = dict() + self.add_http_version_header() def _clear_auth(self): self._site_id = None @@ -144,13 +156,14 @@ def use_highest_version(self): warnings.warn("use use_server_version instead", DeprecationWarning) - def assert_at_least_version(self, version): + def check_at_least_version(self, target: str): server_version = Version(self.version or "0.0") - minimum_supported = Version(version) - if server_version < minimum_supported: - error = "This endpoint is not available in API version {}. Requires {}".format( - server_version, minimum_supported - ) + target_version = Version(target) + return server_version >= target_version + + def assert_at_least_version(self, comparison: str, reason: str): + if not self.check_at_least_version(comparison): + error = "{} is not available in API version {}. Requires {}".format(reason, self.version, comparison) raise EndpointUnavailableError(error) @property diff --git a/test/assets/Data/user_details.csv b/test/assets/Data/user_details.csv new file mode 100644 index 000000000..15b975942 --- /dev/null +++ b/test/assets/Data/user_details.csv @@ -0,0 +1 @@ +username, pword, , yes, email diff --git a/test/assets/Data/usernames.csv b/test/assets/Data/usernames.csv new file mode 100644 index 000000000..0350c0dd6 --- /dev/null +++ b/test/assets/Data/usernames.csv @@ -0,0 +1,7 @@ +valid, +valid@email.com, +domain/valid, +domain/valid@tmail.com, +va!@#$%^&*()lid, +in@v@lid, +in valid, diff --git a/test/assets/site_get_by_id.xml b/test/assets/site_get_by_id.xml index a47703fb6..a8a1e9a5c 100644 --- a/test/assets/site_get_by_id.xml +++ b/test/assets/site_get_by_id.xml @@ -1,4 +1,4 @@ - - \ No newline at end of file + + diff --git a/test/assets/site_get_by_name.xml b/test/assets/site_get_by_name.xml index 852f9594f..b7ae2b595 100644 --- a/test/assets/site_get_by_name.xml +++ b/test/assets/site_get_by_name.xml @@ -1,4 +1,4 @@ - - \ No newline at end of file + + diff --git a/test/assets/site_update.xml b/test/assets/site_update.xml index dbb166de1..1661a426b 100644 --- a/test/assets/site_update.xml +++ b/test/assets/site_update.xml @@ -1,4 +1,4 @@ - - \ No newline at end of file + + diff --git a/test/test_group.py b/test/test_group.py index d948090ca..306d42170 100644 --- a/test/test_group.py +++ b/test/test_group.py @@ -1,9 +1,7 @@ # encoding=utf-8 -import os import unittest - +import os import requests_mock - import tableauserverclient as TSC from tableauserverclient.datetime_helpers import format_datetime @@ -129,8 +127,7 @@ def test_add_user_before_populating(self) -> None: with requests_mock.mock() as m: m.get(self.baseurl, text=get_xml_response) m.post( - "http://test/api/2.3/sites/dad65087-b08b-4603-af4e-2887b8aafc67/groups/ef8b19c0-43b6-11e6-af50" - "-63f5805dbe3c/users", + self.baseurl + "/ef8b19c0-43b6-11e6-af50-63f5805dbe3c/users", text=add_user_response, ) all_groups, pagination_item = self.server.groups.get() @@ -163,8 +160,7 @@ def test_remove_user_before_populating(self) -> None: with requests_mock.mock() as m: m.get(self.baseurl, text=response_xml) m.delete( - "http://test/api/2.3/sites/dad65087-b08b-4603-af4e-2887b8aafc67/groups/ef8b19c0-43b6-11e6-af50" - "-63f5805dbe3c/users/5de011f8-5aa9-4d5b-b991-f462c8dd6bb7", + self.baseurl + "/ef8b19c0-43b6-11e6-af50-63f5805dbe3c/users/5de011f8-5aa9-4d5b-b991-f462c8dd6bb7", text="ok", ) all_groups, pagination_item = self.server.groups.get() diff --git a/test/test_project.py b/test/test_project.py index 1d210eeb1..48e6005af 100644 --- a/test/test_project.py +++ b/test/test_project.py @@ -4,6 +4,7 @@ import requests_mock import tableauserverclient as TSC +from tableauserverclient import GroupItem from ._utils import read_xml_asset, asset TEST_ASSET_DIR = os.path.join(os.path.dirname(__file__), "assets") @@ -120,7 +121,7 @@ def test_update_datasource_default_permission(self) -> None: capabilities = {TSC.Permission.Capability.ExportXml: TSC.Permission.Mode.Deny} - rules = [TSC.PermissionsRule(grantee=group.to_reference(), capabilities=capabilities)] + rules = [TSC.PermissionsRule(grantee=GroupItem.as_reference(group._id), capabilities=capabilities)] new_rules = self.server.projects.update_datasource_default_permissions(project, rules) @@ -237,7 +238,7 @@ def test_delete_permission(self) -> None: if permission.grantee.id == single_group._id: capabilities = permission.capabilities - rules = TSC.PermissionsRule(grantee=single_group.to_reference(), capabilities=capabilities) + rules = TSC.PermissionsRule(grantee=GroupItem.as_reference(single_group._id), capabilities=capabilities) endpoint = "{}/permissions/groups/{}".format(single_project._id, single_group._id) m.delete("{}/{}/Read/Allow".format(self.baseurl, endpoint), status_code=204) @@ -283,7 +284,7 @@ def test_delete_workbook_default_permission(self) -> None: TSC.Permission.Capability.ChangePermissions: TSC.Permission.Mode.Allow, } - rules = TSC.PermissionsRule(grantee=single_group.to_reference(), capabilities=capabilities) + rules = TSC.PermissionsRule(grantee=GroupItem.as_reference(single_group._id), capabilities=capabilities) endpoint = "{}/default-permissions/workbooks/groups/{}".format(single_project._id, single_group._id) m.delete("{}/{}/Read/Allow".format(self.baseurl, endpoint), status_code=204) diff --git a/test/test_requests.py b/test/test_requests.py index 82859dd26..5c0d090ba 100644 --- a/test/test_requests.py +++ b/test/test_requests.py @@ -41,6 +41,7 @@ def test_make_post_request(self): ) self.assertEqual(resp.request.headers["x-tableau-auth"], "j80k54ll2lfMZ0tv97mlPvvSCRyD0DOM") self.assertEqual(resp.request.headers["content-type"], "multipart/mixed") + self.assertTrue(re.search("Tableau Server Client", resp.request.headers["user-agent"])) self.assertEqual(resp.request.body, b"1337") # Test that 500 server errors are handled properly diff --git a/test/test_site.py b/test/test_site.py index 23eb99ddd..b8469e56c 100644 --- a/test/test_site.py +++ b/test/test_site.py @@ -24,6 +24,9 @@ def setUp(self) -> None: self.server._site_id = "0626857c-1def-4503-a7d8-7907c3ff9d9f" self.baseurl = self.server.sites.baseurl + # sites APIs can only be called on the site being logged in to + self.logged_in_site = self.server.site_id + def test_get(self) -> None: with open(GET_XML, "rb") as f: response_xml = f.read().decode("utf-8") @@ -71,10 +74,10 @@ def test_get_by_id(self) -> None: with open(GET_BY_ID_XML, "rb") as f: response_xml = f.read().decode("utf-8") with requests_mock.mock() as m: - m.get(self.baseurl + "/dad65087-b08b-4603-af4e-2887b8aafc67", text=response_xml) - single_site = self.server.sites.get_by_id("dad65087-b08b-4603-af4e-2887b8aafc67") + m.get(self.baseurl + "/" + self.logged_in_site, text=response_xml) + single_site = self.server.sites.get_by_id(self.logged_in_site) - self.assertEqual("dad65087-b08b-4603-af4e-2887b8aafc67", single_site.id) + self.assertEqual(self.logged_in_site, single_site.id) self.assertEqual("Active", single_site.state) self.assertEqual("Default", single_site.name) self.assertEqual("ContentOnly", single_site.admin_mode) @@ -95,7 +98,7 @@ def test_get_by_name(self) -> None: m.get(self.baseurl + "/testsite?key=name", text=response_xml) single_site = self.server.sites.get_by_name("testsite") - self.assertEqual("dad65087-b08b-4603-af4e-2887b8aafc67", single_site.id) + self.assertEqual(self.logged_in_site, single_site.id) self.assertEqual("Active", single_site.state) self.assertEqual("testsite", single_site.name) self.assertEqual("ContentOnly", single_site.admin_mode) @@ -110,7 +113,7 @@ def test_update(self) -> None: with open(UPDATE_XML, "rb") as f: response_xml = f.read().decode("utf-8") with requests_mock.mock() as m: - m.put(self.baseurl + "/6b7179ba-b82b-4f0f-91ed-812074ac5da6", text=response_xml) + m.put(self.baseurl + "/" + self.logged_in_site, text=response_xml) single_site = TSC.SiteItem( name="Tableau", content_url="tableau", @@ -143,10 +146,11 @@ def test_update(self) -> None: tier_explorer_capacity=5, tier_viewer_capacity=5, ) - single_site._id = "6b7179ba-b82b-4f0f-91ed-812074ac5da6" + single_site._id = self.logged_in_site + self.server.sites.parent_srv = self.server single_site = self.server.sites.update(single_site) - self.assertEqual("6b7179ba-b82b-4f0f-91ed-812074ac5da6", single_site.id) + self.assertEqual(self.logged_in_site, single_site.id) self.assertEqual("tableau", single_site.content_url) self.assertEqual("Suspended", single_site.state) self.assertEqual("Tableau", single_site.name) diff --git a/test/test_user.py b/test/test_user.py index b8fe32388..1f5eba57f 100644 --- a/test/test_user.py +++ b/test/test_user.py @@ -1,5 +1,8 @@ +import io import os import unittest +from typing import List +from unittest.mock import MagicMock import requests_mock @@ -17,6 +20,9 @@ GET_FAVORITES_XML = os.path.join(TEST_ASSET_DIR, "favorites_get.xml") POPULATE_GROUPS_XML = os.path.join(TEST_ASSET_DIR, "user_populate_groups.xml") +USERNAMES = os.path.join(TEST_ASSET_DIR, "Data", "usernames.csv") +USERS = os.path.join(TEST_ASSET_DIR, "Data", "user_details.csv") + class UserTests(unittest.TestCase): def setUp(self) -> None: @@ -212,3 +218,21 @@ def test_populate_groups(self) -> None: self.assertEqual("86a66d40-f289-472a-83d0-927b0f954dc8", group_list[2].id) self.assertEqual("TableauExample", group_list[2].name) self.assertEqual("local", group_list[2].domain_name) + + def test_get_usernames_from_file(self): + with open(ADD_XML, "rb") as f: + response_xml = f.read().decode("utf-8") + with requests_mock.mock() as m: + m.post(self.server.users.baseurl, text=response_xml) + user_list, failures = self.server.users.create_from_file(USERNAMES) + assert user_list[0].name == "Cassie", user_list + assert failures == [], failures + + def test_get_users_from_file(self): + with open(ADD_XML, "rb") as f: + response_xml = f.read().decode("utf-8") + with requests_mock.mock() as m: + m.post(self.server.users.baseurl, text=response_xml) + users, failures = self.server.users.create_from_file(USERS) + assert users[0].name == "Cassie", users + assert failures == [] diff --git a/test/test_user_model.py b/test/test_user_model.py index ba70b1c7c..32d808f52 100644 --- a/test/test_user_model.py +++ b/test/test_user_model.py @@ -1,4 +1,10 @@ +import logging import unittest +from unittest.mock import * +from typing import List +import io + +import pytest import tableauserverclient as TSC @@ -23,3 +29,111 @@ def test_invalid_site_role(self): user = TSC.UserItem("me", TSC.UserItem.Roles.Publisher) with self.assertRaises(ValueError): user.site_role = "Hello" + + +class UserDataTest(unittest.TestCase): + + logger = logging.getLogger("UserDataTest") + + role_inputs = [ + ["creator", "system", "yes", "SiteAdministrator"], + ["None", "system", "no", "SiteAdministrator"], + ["explorer", "SysTEm", "no", "SiteAdministrator"], + ["creator", "site", "yes", "SiteAdministratorCreator"], + ["explorer", "site", "yes", "SiteAdministratorExplorer"], + ["creator", "SITE", "no", "SiteAdministratorCreator"], + ["creator", "none", "yes", "Creator"], + ["explorer", "none", "yes", "ExplorerCanPublish"], + ["viewer", "None", "no", "Viewer"], + ["explorer", "no", "yes", "ExplorerCanPublish"], + ["EXPLORER", "noNO", "yes", "ExplorerCanPublish"], + ["explorer", "no", "no", "Explorer"], + ["unlicensed", "none", "no", "Unlicensed"], + ["Chef", "none", "yes", "Unlicensed"], + ["yes", "yes", "yes", "Unlicensed"], + ] + + valid_import_content = [ + "username, pword, fname, creator, site, yes, email", + "username, pword, fname, explorer, none, no, email", + "", + "u", + "p", + ] + + valid_username_content = ["jfitzgerald@tableau.com"] + + usernames = [ + "valid", + "valid@email.com", + "domain/valid", + "domain/valid@tmail.com", + "va!@#$%^&*()lid", + "in@v@lid", + "in valid", + "", + ] + + def test_validate_usernames(self): + TSC.UserItem.validate_username_or_throw(UserDataTest.usernames[0]) + TSC.UserItem.validate_username_or_throw(UserDataTest.usernames[1]) + TSC.UserItem.validate_username_or_throw(UserDataTest.usernames[2]) + TSC.UserItem.validate_username_or_throw(UserDataTest.usernames[3]) + TSC.UserItem.validate_username_or_throw(UserDataTest.usernames[4]) + with self.assertRaises(AttributeError): + TSC.UserItem.validate_username_or_throw(UserDataTest.usernames[5]) + with self.assertRaises(AttributeError): + TSC.UserItem.validate_username_or_throw(UserDataTest.usernames[6]) + + def test_evaluate_role(self): + for line in UserDataTest.role_inputs: + actual = TSC.UserItem.CSVImport._evaluate_site_role(line[0], line[1], line[2]) + assert actual == line[3], line + [actual] + + def test_get_user_detail_empty_line(self): + test_line = "" + test_user = TSC.UserItem.CSVImport.create_user_from_line(test_line) + assert test_user is None + + def test_get_user_detail_standard(self): + test_line = "username, pword, fname, license, admin, pub, email" + test_user: TSC.UserItem = TSC.UserItem.CSVImport.create_user_from_line(test_line) + assert test_user.name == "username", test_user.name + assert test_user.fullname == "fname", test_user.fullname + assert test_user.site_role == "Unlicensed", test_user.site_role + assert test_user.email == "email", test_user.email + + def test_get_user_details_only_username(self): + test_line = "username" + test_user: TSC.UserItem = TSC.UserItem.CSVImport.create_user_from_line(test_line) + + def test_populate_user_details_only_some(self): + values = "username, , , creator, admin" + user = TSC.UserItem.CSVImport.create_user_from_line(values) + assert user.name == "username" + + def test_validate_user_detail_standard(self): + test_line = "username, pword, fname, creator, site, 1, email" + TSC.UserItem.CSVImport._validate_import_line_or_throw(test_line, UserDataTest.logger) + TSC.UserItem.CSVImport.create_user_from_line(test_line) + + # for file handling + def _mock_file_content(self, content: List[str]) -> io.TextIOWrapper: + # the empty string represents EOF + # the tests run through the file twice, first to validate then to fetch + mock = MagicMock(io.TextIOWrapper) + content.append("") # EOF + mock.readline.side_effect = content + mock.name = "file-mock" + return mock + + def test_validate_import_file(self): + test_data = self._mock_file_content(UserDataTest.valid_import_content) + valid, invalid = TSC.UserItem.CSVImport.validate_file_for_import(test_data, UserDataTest.logger) + assert valid == 2, "Expected two lines to be parsed, got {}".format(valid) + assert invalid == [], "Expected no failures, got {}".format(invalid) + + def test_validate_usernames_file(self): + test_data = self._mock_file_content(UserDataTest.usernames) + valid, invalid = TSC.UserItem.CSVImport.validate_file_for_import(test_data, UserDataTest.logger) + assert valid == 5, "Exactly 5 of the lines were valid, counted {}".format(valid + invalid) diff --git a/versioneer.py b/versioneer.py index 59211ed6f..86c240e13 100755 --- a/versioneer.py +++ b/versioneer.py @@ -277,6 +277,7 @@ """ from __future__ import print_function + try: import configparser except ImportError: @@ -308,11 +309,13 @@ def get_root(): setup_py = os.path.join(root, "setup.py") versioneer_py = os.path.join(root, "versioneer.py") if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ("Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND').") + err = ( + "Versioneer was unable to run the project root directory. " + "Versioneer requires setup.py to be executed from " + "its immediate directory (like 'python setup.py COMMAND'), " + "or in a way that lets it use sys.argv[0] to find the root " + "(like 'python path/to/setup.py COMMAND')." + ) raise VersioneerBadRootError(err) try: # Certain runtime workflows (setup.py install/develop in a setuptools @@ -325,8 +328,7 @@ def get_root(): me_dir = os.path.normcase(os.path.splitext(me)[0]) vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) if me_dir != vsr_dir: - print("Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(me), versioneer_py)) + print("Warning: build in %s is using versioneer.py from %s" % (os.path.dirname(me), versioneer_py)) except NameError: pass return root @@ -348,6 +350,7 @@ def get(parser, name): if parser.has_option("versioneer", name): return parser.get("versioneer", name) return None + cfg = VersioneerConfig() cfg.VCS = VCS cfg.style = get(parser, "style") or "" @@ -372,17 +375,18 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs, method): # decorator """Decorator to mark a method as the handler for a particular VCS.""" + def decorate(f): """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) p = None @@ -390,10 +394,9 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, try: dispcmd = str([c] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + p = subprocess.Popen( + [c] + args, cwd=cwd, env=env, stdout=subprocess.PIPE, stderr=(subprocess.PIPE if hide_stderr else None) + ) break except EnvironmentError: e = sys.exc_info()[1] @@ -418,7 +421,9 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, return stdout, p.returncode -LONG_VERSION_PY['git'] = ''' +LONG_VERSION_PY[ + "git" +] = ''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -993,7 +998,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -1002,7 +1007,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) + tags = set([r for r in refs if re.search(r"\d", r)]) if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -1010,19 +1015,26 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") @@ -1037,8 +1049,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -1046,10 +1057,9 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) + describe_out, rc = run_command( + GITS, ["describe", "--tags", "--dirty", "--always", "--long", "--match", "%s*" % tag_prefix], cwd=root + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -1072,17 +1082,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -1091,10 +1100,9 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % (full_tag, tag_prefix) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -1105,13 +1113,11 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) pieces["distance"] = int(count_out) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() + date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces @@ -1167,16 +1173,19 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): for i in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } else: rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print("Tried directories %s but none started with prefix %s" % (str(rootdirs), parentdir_prefix)) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -1205,11 +1214,9 @@ def versions_from_file(filename): contents = f.read() except EnvironmentError: raise NotThisMethod("unable to read _version.py") - mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) + mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S) if not mo: - mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) + mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", contents, re.M | re.S) if not mo: raise NotThisMethod("no version_json in _version.py") return json.loads(mo.group(1)) @@ -1218,8 +1225,7 @@ def versions_from_file(filename): def write_to_version_file(filename, versions): """Write the given version number to the given _version.py file.""" os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, - indent=1, separators=(",", ": ")) + contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) with open(filename, "w") as f: f.write(SHORT_VERSION_PY % contents) @@ -1251,8 +1257,7 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -1366,11 +1371,13 @@ def render_git_describe_long(pieces): def render(pieces, style): """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -1390,9 +1397,13 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } class VersioneerBadRootError(Exception): @@ -1415,8 +1426,7 @@ def get_versions(verbose=False): handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS verbose = verbose or cfg.verbose - assert cfg.versionfile_source is not None, \ - "please set versioneer.versionfile_source" + assert cfg.versionfile_source is not None, "please set versioneer.versionfile_source" assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" versionfile_abs = os.path.join(root, cfg.versionfile_source) @@ -1470,9 +1480,13 @@ def get_versions(verbose=False): if verbose: print("unable to compute version") - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, "error": "unable to compute version", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } def get_version(): @@ -1521,6 +1535,7 @@ def run(self): print(" date: %s" % vers.get("date")) if vers["error"]: print(" error: %s" % vers["error"]) + cmds["version"] = cmd_version # we override "build_py" in both distutils and setuptools @@ -1553,14 +1568,15 @@ def run(self): # now locate _version.py in the new build/ directory and replace # it with an updated value if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) + target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) + cmds["build_py"] = cmd_build_py if "cx_Freeze" in sys.modules: # cx_freeze enabled? from cx_Freeze.dist import build_exe as _build_exe + # nczeczulin reports that py2exe won't like the pep440-style string # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. # setup(console=[{ @@ -1581,17 +1597,21 @@ def run(self): os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + cmds["build_exe"] = cmd_build_exe del cmds["build_py"] - if 'py2exe' in sys.modules: # py2exe enabled? + if "py2exe" in sys.modules: # py2exe enabled? try: from py2exe.distutils_buildexe import py2exe as _py2exe # py3 except ImportError: @@ -1610,13 +1630,17 @@ def run(self): os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + cmds["py2exe"] = cmd_py2exe # we override different "sdist" commands for both environments @@ -1643,8 +1667,8 @@ def make_release_tree(self, base_dir, files): # updated value target_versionfile = os.path.join(base_dir, cfg.versionfile_source) print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, - self._versioneer_generated_versions) + write_to_version_file(target_versionfile, self._versioneer_generated_versions) + cmds["sdist"] = cmd_sdist return cmds @@ -1699,11 +1723,9 @@ def do_setup(): root = get_root() try: cfg = get_config_from_root(root) - except (EnvironmentError, configparser.NoSectionError, - configparser.NoOptionError) as e: + except (EnvironmentError, configparser.NoSectionError, configparser.NoOptionError) as e: if isinstance(e, (EnvironmentError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", - file=sys.stderr) + print("Adding sample versioneer config to setup.cfg", file=sys.stderr) with open(os.path.join(root, "setup.cfg"), "a") as f: f.write(SAMPLE_CONFIG) print(CONFIG_ERROR, file=sys.stderr) @@ -1712,15 +1734,18 @@ def do_setup(): print(" creating %s" % cfg.versionfile_source) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), - "__init__.py") + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + + ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") if os.path.exists(ipy): try: with open(ipy, "r") as f: @@ -1762,8 +1787,7 @@ def do_setup(): else: print(" 'versioneer.py' already in MANIFEST.in") if cfg.versionfile_source not in simple_includes: - print(" appending versionfile_source ('%s') to MANIFEST.in" % - cfg.versionfile_source) + print(" appending versionfile_source ('%s') to MANIFEST.in" % cfg.versionfile_source) with open(manifest_in, "a") as f: f.write("include %s\n" % cfg.versionfile_source) else: