Spaces:
Runtime error
Runtime error
| from functools import lru_cache | |
| import json | |
| import math | |
| import re | |
| from typing import List, Union | |
| from pathlib import Path | |
| import torch | |
| from torch import Tensor | |
def load_text_file(
    file_path: Union[Path, str],
    encoding='utf-8',
    *args, **kwargs
) -> str:
    """Read a whole text file and return its contents as one string.

    Args:
        file_path: Path of the file to read.
        encoding: Codec used to decode the file. Defaults to UTF-8.
        *args: Accepted but ignored.
        **kwargs: Accepted but ignored.

    Returns:
        str: The complete decoded contents of the file.
    """
    with open(file_path, mode='r', encoding=encoding) as handle:
        contents = handle.read()
    return contents
def save_text_file(
    file_path: Union[Path, str],
    data: str,
    encoding='utf-8'
) -> int:
    """Write ``data`` to ``file_path``, replacing any existing content.

    Args:
        file_path: Destination file path.
        data: Text to write.
        encoding: Codec used to encode the text. Defaults to UTF-8.

    Returns:
        int: Number of characters written, as reported by ``write()``.
    """
    with open(file_path, 'w', encoding=encoding) as f:
        # write() returns the character count, not the text, hence ``-> int``.
        return f.write(data)
def remove_long_spaces(line: str) -> str:
    """Collapse each run of two or more whitespace characters into one space.

    A lone whitespace character (e.g. a single tab or newline) is left as-is.

    Args:
        line: Input text.

    Returns:
        str: ``line`` with every whitespace run of length >= 2 replaced by ' '.
    """
    # Raw string: '\s' inside a plain literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning from Python 3.12).
    return re.sub(r'\s{2,}', ' ', line)
def get_positionals(max_length: int, d_model: int) -> Tensor:
    """Create the sinusoidal positionals tensor to be added to the input.

    Even columns ``i`` hold ``sin(pos / 10000 ** (2 * i / d_model))`` and the
    following odd column ``i + 1`` holds the matching cosine. When
    ``d_model`` is odd, the last (even) column has only its sine term.

    Args:
        max_length (int): The maximum length of the positionals sequence.
        d_model (int): The dimensionality of the positionals sequence.

    Returns:
        Tensor: float32 tensor of shape ``(max_length, d_model)``.
    """
    # Compute in float64 (the precision of the original math.sin/math.cos
    # loop) and cast once to float32 for storage.
    positions = torch.arange(max_length, dtype=torch.float64).unsqueeze(1)
    even_cols = torch.arange(0, d_model, 2, dtype=torch.float64)
    # NOTE(review): ``i`` is already the even column index, so the textbook
    # formula would use ``i / d_model``. ``2 * i / d_model`` is kept on purpose:
    # changing it would alter the encodings that existing checkpoints use.
    denominators = torch.pow(10000.0, 2 * even_cols / d_model)
    angles = positions / denominators  # (max_length, ceil(d_model / 2))

    result = torch.zeros(max_length, d_model, dtype=torch.float)
    result[:, 0::2] = torch.sin(angles).float()
    # With an odd d_model there is one fewer odd column than even columns;
    # slicing here avoids the IndexError the per-element loop raised.
    result[:, 1::2] = torch.cos(angles[:, : d_model // 2]).float()
    return result
def load_json(
    file_path: Union[Path, str], encoding: str = 'utf-8'
) -> Union[dict, list]:
    """Parse a JSON file and return the decoded object.

    Args:
        file_path: Path of the JSON file to read.
        encoding: Codec used to decode the file. Defaults to UTF-8 so that
            results do not depend on the platform's locale encoding.

    Returns:
        Union[dict, list]: The top-level JSON value.

    Raises:
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(file_path, 'r', encoding=encoding) as f:
        return json.load(f)
def save_json(
    file_path: Union[Path, str],
    data: Union[dict, list],
    encoding: str = 'utf-8'
) -> None:
    """Serialize ``data`` as JSON into ``file_path``, replacing its content.

    ``json.dump`` is called with its defaults, so non-ASCII characters are
    written as ``\\uXXXX`` escapes.

    Args:
        file_path: Destination file path.
        data: JSON-serializable object to write.
        encoding: Codec used to write the file. Defaults to UTF-8, which
            matches ``save_text_file`` and avoids the locale default.
    """
    with open(file_path, 'w', encoding=encoding) as f:
        json.dump(data, f)
def get_freq_dict(data: List[str]) -> dict:
    """Count how many times each space-separated token appears in ``data``.

    Strings are split on the single character ``' '`` rather than on
    arbitrary whitespace. Consecutive spaces therefore produce empty-string
    tokens, and those are counted as well.

    Args:
        data: Strings to tokenize.

    Returns:
        dict: Mapping of token to occurrence count, in first-seen order.
    """
    counts = {}
    tokens = (token for text in data for token in text.split(' '))
    for token in tokens:
        counts[token] = counts.get(token, 0) + 1
    return counts
def load_state(state_path: Union[Path, str], map_location=None):
    """Load a checkpoint file and unpack its components.

    The checkpoint must be a mapping with the keys ``'model'``,
    ``'optimizer'``, ``'epoch'`` and ``'steps'``. A leading ``'module.'``
    prefix is removed from every model key. Presumably this prefix comes from
    saving a DataParallel/DDP-wrapped model; verify against the training
    script.

    Args:
        state_path: Path to the file written by ``torch.save``.
        map_location: Forwarded to ``torch.load``; pass e.g. ``'cpu'`` to
            load a GPU-saved checkpoint on a CPU-only machine. ``None`` keeps
            ``torch.load``'s default device placement.

    Returns:
        tuple: ``(model_state_dict, optimizer_state, epoch, steps)``.

    Raises:
        KeyError: If one of the expected keys is missing from the checkpoint.
    """
    # SECURITY NOTE(review): torch.load unpickles the file, which can run
    # arbitrary code; only load checkpoints from trusted sources.
    state = torch.load(state_path, map_location=map_location)
    prefix = 'module.'
    # Strip only a leading prefix. str.replace would also mangle keys that
    # merely contain 'module.', e.g. 'encoder.submodule.weight'.
    model = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state['model'].items()
    }
    return model, state['optimizer'], state['epoch'], state['steps']