Skip to content

Support local offline environment to read community config and model files #5817

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion paddlenlp/transformers/auto/modeling.py
Original file line number Diff line number Diff line change
@@ -340,15 +340,21 @@ def _from_pretrained(
logger.warning(f"{config_file} is not a valid path to a model config file")
# Assuming from community-contributed pretrained models
else:
cached_standard_config = os.path.join(cache_dir, cls.model_config_file)
cached_legacy_config = os.path.join(cache_dir, cls.legacy_model_config_file)
standard_community_url = "/".join(
[COMMUNITY_MODEL_PREFIX, pretrained_model_name_or_path, cls.model_config_file]
)
legacy_community_url = "/".join(
[COMMUNITY_MODEL_PREFIX, pretrained_model_name_or_path, cls.legacy_model_config_file]
)
try:
if url_file_exists(standard_community_url):
if os.path.isfile(cached_standard_config):
resolved_vocab_file = cached_standard_config
elif url_file_exists(standard_community_url):
resolved_vocab_file = get_path_from_url_with_filelock(standard_community_url, cache_dir)
elif os.path.isfile(cached_legacy_config):
resolved_vocab_file = cached_legacy_config
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这样应该可以解决上面提到的 AutoModel 无法在离线环境使用已下载好的社区模型的问题了。

通过已下载模型的绝对路径进行加载,确实能够在离线环境下使用。但第一次下载社区模型时用的是 community/model-name 的形式,下载后在离线环境使用时,还要把代码再改成 /path/to/community/model-name 的形式,总是要麻烦一些。不如直接在下载社区模型前检查一下 cache_dir 中是否已经包含下载好的文件,有的话直接使用,这样也避免了校验配置文件是否存在等网络请求。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

elif url_file_exists(legacy_community_url):
logger.info("Standard config do not exist, loading from legacy config")
resolved_vocab_file = get_path_from_url_with_filelock(legacy_community_url, cache_dir)
2 changes: 2 additions & 0 deletions paddlenlp/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
@@ -841,6 +841,8 @@ def _get_config_dict(
resolved_config_file = resolve_hf_config_path(
repo_id=pretrained_model_name_or_path, cache_dir=cache_dir, subfolder=subfolder
)
elif os.path.isfile(os.path.join(cache_dir, CONFIG_NAME)):
resolved_config_file = os.path.join(cache_dir, CONFIG_NAME)
Comment on lines +844 to +845
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里会优先根据变量 pretrained_model_name_or_path(可能是 file-path、directory-path 或 url)来加载文件;cache_dir 只是下载文件的缓存路径,并不能作为检索文件的路径。

如果你想从 cache_dir 来加载文件的话,可以手动指定路径:

# 从默认缓存路径加载模型
AutoModel.from_pretrained("bert-base-uncased")

# 从自定义缓存路径加载模型
AutoModel.from_pretrained("cache_dir")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感谢您的 review。在 description 所描述的环境中(不包含上述代码调整),我联网时执行如下命令,能够成功:

>>> from paddlenlp.transformers import AutoModel
>>> AutoModel.from_pretrained("Salesforce/codegen-350M-mono")

但当我断开网络连接,再次执行时,则会遇到与上面描述中类似的问题(目前这个 PR 的改动也没能解决这个方法在离线环境不可用的问题),堆栈信息如下:

Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/urllib3/connection.py", line 174, in _new_conn
    conn = connection.create_connection(
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/urllib3/util/connection.py", line 72, in create_connection
    for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/socket.py", line 954, in getaddrinfo
    for res in _socket.getaddrinfo(host, port, family, type, proto, flags):
socket.gaierror: [Errno 8] nodename nor servname provided, or not known

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/urllib3/connectionpool.py", line 703, in urlopen
    httplib_response = self._make_request(
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/urllib3/connectionpool.py", line 386, in _make_request
    self._validate_conn(conn)
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/urllib3/connectionpool.py", line 1042, in _validate_conn
    conn.connect()
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/urllib3/connection.py", line 363, in connect
    self.sock = conn = self._new_conn()
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/urllib3/connection.py", line 186, in _new_conn
    raise NewConnectionError(
urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x143035160>: Failed to establish a new connection: [Errno 8] nodename nor servname provided, or not known

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/requests/adapters.py", line 489, in send
    resp = conn.urlopen(
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/urllib3/connectionpool.py", line 787, in urlopen
    retries = retries.increment(
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/urllib3/util/retry.py", line 592, in increment
    raise MaxRetryError(_pool, url, error or ResponseError(cause))
urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='bj.bcebos.com', port=443): Max retries exceeded with url: /paddlenlp/models/community/Salesforce/codegen-350M-mono/config.json (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x143035160>: Failed to establish a new connection: [Errno 8] nodename nor servname provided, or not known'))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/paddlenlp/transformers/auto/modeling.py", line 478, in from_pretrained
    return cls._from_pretrained(pretrained_model_name_or_path, task, *model_args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/paddlenlp/transformers/auto/modeling.py", line 341, in _from_pretrained
    if url_file_exists(standard_community_url):
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/paddlenlp/utils/downloader.py", line 440, in url_file_exists
    result = requests.head(url)
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/requests/api.py", line 100, in head
    return request("head", url, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/requests/api.py", line 59, in request
    return session.request(method=method, url=url, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/requests/sessions.py", line 587, in request
    resp = self.send(prep, **send_kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/requests/sessions.py", line 701, in send
    r = adapter.send(request, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/requests/adapters.py", line 565, in send
    raise ConnectionError(e, request=request)
requests.exceptions.ConnectionError: HTTPSConnectionPool(host='bj.bcebos.com', port=443): Max retries exceeded with url: /paddlenlp/models/community/Salesforce/codegen-350M-mono/config.json (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x143035160>: Failed to establish a new connection: [Errno 8] nodename nor servname provided, or not known'))

不知道我对“从默认缓存路径加载模型”这个用法的理解和传入的参数是否有误。

else:
community_url = "/".join([COMMUNITY_MODEL_PREFIX, pretrained_model_name_or_path, CONFIG_NAME])
if url_file_exists(community_url):
2 changes: 2 additions & 0 deletions paddlenlp/transformers/model_utils.py
Original file line number Diff line number Diff line change
@@ -877,6 +877,8 @@ def _resolve_model_file_path(
# 0. when it is local file
if os.path.isfile(pretrained_model_name_or_path):
return pretrained_model_name_or_path
elif os.path.isfile(os.path.join(cache_dir, cls.resource_files_names["model_state"])):
return os.path.join(cache_dir, cls.resource_files_names["model_state"])
Comment on lines +880 to +881
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此部分代码说明同上。


# 1. when it is model-name
if pretrained_model_name_or_path in cls.pretrained_init_configuration: