Instructions to use studyOverflow/egrpo with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use studyOverflow/egrpo with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for Apple devices
pipe = DiffusionPipeline.from_pretrained("studyOverflow/egrpo", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
| #This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0. | |
| import os | |
| import torch.distributed as dist | |
class COMM_INFO:
    """Holds this process's sequence-parallel communication settings.

    The defaults describe a single-process layout: no process group, a
    sequence-parallel size of 1, and rank/group indices of 0.
    """

    def __init__(self):
        # Process group this rank communicates through; None until assigned.
        self.group = None
        # Number of ranks that make up one sequence-parallel group.
        self.sp_size = 1
        # Global rank, rank inside its group, and index of that group.
        self.global_rank = self.rank_within_group = self.group_id = 0


# Module-level instance describing the current process's layout.
nccl_info = COMM_INFO()

# Whether sequence parallelism is enabled; False until switched on.
_SEQUENCE_PARALLEL_STATE = False
def initialize_sequence_parallel_state(sequence_parallel_size):
    """Configure the module-level sequence-parallel state for this process.

    If ``sequence_parallel_size`` is greater than 1, the process groups are
    built by ``initialize_sequence_parallel_group`` and the sequence-parallel
    flag is switched on. Otherwise the flag is switched off and ``nccl_info``
    is filled in so this process is treated as a group of size 1 whose group
    index equals its global rank (taken from the ``RANK`` environment
    variable, default 0).

    Args:
        sequence_parallel_size: Number of ranks that share one sequence.
    """
    global _SEQUENCE_PARALLEL_STATE
    if sequence_parallel_size > 1:
        initialize_sequence_parallel_group(sequence_parallel_size)
        # Only flag SP as enabled once the groups were created successfully.
        _SEQUENCE_PARALLEL_STATE = True
    else:
        # Clear the flag explicitly so re-initializing with size <= 1 does not
        # leave a stale True from an earlier call.
        _SEQUENCE_PARALLEL_STATE = False
        rank = int(os.getenv("RANK", "0"))
        nccl_info.sp_size = 1
        nccl_info.global_rank = rank
        nccl_info.rank_within_group = 0
        # One rank per group, so the group index is the global rank.
        nccl_info.group_id = rank
def set_sequence_parallel_state(state):
    """Overwrite the module-wide sequence-parallel flag with ``state``."""
    globals()["_SEQUENCE_PARALLEL_STATE"] = state
def get_sequence_parallel_state():
    """Report whether sequence parallelism is currently switched on."""
    enabled = _SEQUENCE_PARALLEL_STATE
    return enabled
def initialize_sequence_parallel_group(sequence_parallel_size):
    """Initialize the sequence parallel group.

    Reads ``RANK`` (default 0) and ``WORLD_SIZE`` (default 1) from the
    environment and splits the world into ``WORLD_SIZE // sequence_parallel_size``
    groups of consecutive ranks: group ``i`` holds ranks
    ``[i * sequence_parallel_size, (i + 1) * sequence_parallel_size)``.
    Every group is created on every rank; ``torch.distributed.new_group``
    expects all processes to enter it, even those that are not members.
    The group containing this rank is recorded in ``nccl_info`` together with
    this rank's position inside it and the group index.

    Args:
        sequence_parallel_size: Number of ranks per sequence-parallel group.

    Raises:
        ValueError: If ``sequence_parallel_size`` is not positive or does not
            evenly divide the world size.
    """
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    # A size of 0 would otherwise raise ZeroDivisionError below, and a
    # negative size would silently create no groups at all.
    if sequence_parallel_size <= 0:
        raise ValueError(
            f"sequence_parallel_size must be positive, but got {sequence_parallel_size}")
    if world_size % sequence_parallel_size != 0:
        raise ValueError(
            "world_size must be divisible by sequence_parallel_size, but got "
            f"world_size: {world_size}, sequence_parallel_size: {sequence_parallel_size}")
    nccl_info.sp_size = sequence_parallel_size
    nccl_info.global_rank = rank
    num_sequence_parallel_groups = world_size // sequence_parallel_size
    for group_index in range(num_sequence_parallel_groups):
        start = group_index * sequence_parallel_size
        ranks = range(start, start + sequence_parallel_size)
        group = dist.new_group(ranks)
        if rank in ranks:
            nccl_info.group = group
            nccl_info.rank_within_group = rank - start
            nccl_info.group_id = group_index
def destroy_sequence_parallel_group():
    """Tear down torch.distributed, including the sequence-parallel groups.

    ``torch.distributed.destroy_process_group()`` without an argument destroys
    the default process group and every subgroup. This is therefore a global
    teardown, not one limited to the sequence-parallel group. If no process
    group has been initialized, the call is a no-op, so it is safe to call
    more than once.
    """
    # Guard so teardown does not raise when torch.distributed was never
    # initialized or has already been destroyed.
    if dist.is_initialized():
        dist.destroy_process_group()