Open
30 commits
d5d6136
Add arguments for 2D pose model overrides in opts class
xiu-cs May 15, 2026
3aea87e
Add configuration file for HRNet-w32 backbone fine-tuned on Animal3D
xiu-cs May 15, 2026
0d7d16f
Add initial configuration file for animal 2D detector and HRNet-w32 f…
xiu-cs May 15, 2026
cfe20da
Enhance SuperAnimalConfig with fine-tuning options and detailed mode …
xiu-cs May 15, 2026
f66e63f
Update vis_animals.sh to include saved_2d_model_path and reset saved_…
xiu-cs May 15, 2026
59c960d
Refactor 2D pose estimation in vis_animals.py to use SuperAnimalEstim…
xiu-cs May 15, 2026
06cbfc1
Refactor resolve_weights_path function in weights.py for improved cla…
xiu-cs May 15, 2026
832c2d4
Update README.md to clarify pre-trained model usage and auto-download…
xiu-cs May 15, 2026
c439d22
Remove unused joint variables from vis_animals.sh
xiu-cs May 15, 2026
674afbe
Update model weights path resolution to include file extension
xiu-cs May 15, 2026
e4ec633
Enhance SuperAnimalEstimator to support fine-tuned model loading and …
xiu-cs May 15, 2026
090dcfc
Update README.md to enhance SuperAnimalEstimator description with fin…
xiu-cs May 15, 2026
95293b4
Update model weights path resolution to include file extension for lo…
xiu-cs May 15, 2026
87dc5f4
Remove unused modules and classes related to 3D pose estimation
xiu-cs May 15, 2026
ed3f084
Update DatasetConfig references from "rat7m" to "animal3d" and adjust…
xiu-cs May 15, 2026
bc4a886
Remove references to Rat7M dataset from Graph class and related docum…
xiu-cs May 15, 2026
f32ca40
Update action placeholder in main_animal3d.py for Animal3D dataset
xiu-cs May 15, 2026
54d1cb6
Update dataset default value and root path in arguments.py for Animal…
xiu-cs May 15, 2026
ff801e5
Add SuperAnimalConfig support and unit tests for fine-tuned mode in S…
xiu-cs May 15, 2026
3279e88
Refactor get_pose2D function to streamline 2D pose estimation using S…
xiu-cs May 15, 2026
4c20e0f
Refactor 2D and 3D pose estimation functions to build estimators once…
xiu-cs May 15, 2026
54ffab2
Add 'mot' to ignore words list in Codespell workflow
xiu-cs May 17, 2026
e77e7f0
Refactor model loading to use device-agnostic code for CUDA compatibi…
xiu-cs May 17, 2026
6335c0b
Refactor HRNetPose2d to use device-agnostic code for model and input …
xiu-cs May 17, 2026
4ebd20c
Update Python version requirement and refine dependency constraints i…
xiu-cs May 17, 2026
e3fef16
Remove unused import of coco_h36m from utilitys.py
xiu-cs May 17, 2026
621b6a6
Remove unnecessary configuration files and associated data
xiu-cs May 17, 2026
34a824f
Update README.md to clarify Python version requirement and add PyTorc…
xiu-cs May 17, 2026
c6b1c4a
Fix checkpoint directory handling and improve weight loading logic in…
xiu-cs May 19, 2026
6131b5e
Refactor model path handling in test_animal3d.sh to clarify usage of …
xiu-cs May 19, 2026
2 changes: 1 addition & 1 deletion .github/workflows/codespell.yml
@@ -18,4 +18,4 @@ jobs:
- name: Codespell
uses: codespell-project/actions-codespell@v1
with:
ignore_words_list: fmpose, mpjpe, uvd, xyz, hm36, cpn, dbb
ignore_words_list: fmpose, mpjpe, uvd, xyz, hm36, cpn, dbb, mot
6 changes: 4 additions & 2 deletions README.md
@@ -31,7 +31,7 @@ FMPose3D creates a 3D pose from a single 2D image. It leverages fast Flow Matchi

### Set up an environment

Make sure you have Python 3.10+. You can set this up with:
Make sure you have Python 3.10. The installation and demos are tested with Python 3.10. You can set this up with:
```bash
conda create -n fmpose_3d python=3.10
conda activate fmpose_3d
@@ -45,6 +45,8 @@ For the animal pipeline, install the optional DeepLabCut dependency:
pip install "fmpose3d[animals]"
```

> **PyTorch/CUDA note.** FMPose3D pins `torch>=2.4.1,<2.5` and `torchvision>=0.19.1,<0.20`, which use CUDA 12.1 wheels by default on Linux. If your driver does not support CUDA 12.1, or if you need a specific CUDA build, install PyTorch first using the matching command from [pytorch.org](https://pytorch.org/get-started/locally/), then install `fmpose3d`.

## Demos

### Testing on in-the-wild images (humans)
@@ -108,7 +110,7 @@ FMPose3D also ships a high-level Python API for end-to-end 3D pose estimation fr

## Experiments on non-human animals

For animal training/testing and demo scripts, see [animals/README.md](animals/README.md).
For animal training/testing and demo scripts, see [animals/README.md](animals/README.md). The animal demo **auto-downloads both checkpoints** (a 26-joint SuperAnimal-Quadruped fine-tuned on Animal3D for 2D pose, and the FMPose3D animal flow-matching lifter for 3D) from [Hugging Face](https://huggingface.co/MLAdaptiveIntelligence/FMPose3D) on first run — no manual setup needed.

## Citation

8 changes: 5 additions & 3 deletions animals/README.md
@@ -9,8 +9,10 @@ In this part, the FMPose3D model is trained on [Animal3D](https://xujiacong.gith

This visualization script is designed for single-frame based model, allowing you to easily run 3D animal pose estimation on any single image.

Before testing, make sure you have the pre-trained model ready.
You may either use the model trained by your own or download ours from [here](https://drive.google.com/drive/folders/1kL4aOyWNq0o9zB0rSTRM8KYgkySVmUTk?usp=drive_link) and place it in the `./pre_trained_models` directory.
Both pre-trained checkpoints are **auto-downloaded from [Hugging Face](https://huggingface.co/MLAdaptiveIntelligence/FMPose3D)** on first run and cached under `~/.cache/huggingface/`. No manual downloads required.

- **3D lifter** (`fmpose3d_animals.pth`) — Animal3D 26-joint flow-matching 2D→3D lifter. Override: set `saved_model_path` in `vis_animals.sh` to a local `.pth`.
- **2D pose model** (`sa_finetune_hrnet_w32.pt`) — SuperAnimal-Quadruped HRNet-w32 fine-tuned on Animal3D for the 26-joint Animal3D output layout. Override: set `saved_2d_model_path` in `vis_animals.sh` to a local `.pt`.
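
The override rule described in the two bullets above (an empty path triggers the auto-download, a non-empty path is used as a local file) can be sketched as follows. This is only an illustration under stated assumptions, not the actual `resolve_weights_path` in `fmpose3d.utils.weights`: the name `resolve_weights_path_sketch`, the injected `download` callable, and the early existence check are all hypothetical.

```python
from pathlib import Path
from typing import Callable


def resolve_weights_path_sketch(
    local_path: str,
    filename: str,
    download: Callable[[str], str],
) -> str:
    """Return a usable checkpoint path (illustrative only).

    local_path: user-supplied override; an empty string means "not set".
    filename:   checkpoint name to fetch when no override is given.
    download:   hypothetical callable that fetches `filename` and returns
                the path of the cached copy.
    """
    if local_path:
        # Non-empty override: use the local file, but fail early if it is missing.
        path = Path(local_path).expanduser()
        if not path.is_file():
            raise FileNotFoundError(f"Checkpoint not found: {path}")
        return str(path)
    # Empty override: fall back to the auto-download route.
    return download(filename)
```

Passing the downloader in as a callable keeps the sketch runnable offline; in practice it would presumably wrap whatever Hugging Face download helper the project uses.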

Next, put your test images into folder `demo/images`. Then run the visualization script:
```bash
@@ -49,7 +51,7 @@ Place the downloaded files in the `dataset/` folder of this project:
## Training
The training logs, checkpoints, and related files of each training time will be saved in the './checkpoint' folder.

For trainig on the two datasets:
For training on the two datasets:

```bash
cd animals
233 changes: 71 additions & 162 deletions animals/demo/vis_animals.py
@@ -46,21 +46,6 @@
from fmpose3d.models import get_model
CFM = get_model(args.model_type)

try:
from deeplabcut.pose_estimation_pytorch.apis import ( # pyright: ignore[reportMissingImports]
superanimal_analyze_images,
)
except ImportError:
raise ImportError(
"DeepLabCut is required for the animal demo. "
"Install it with: pip install \"fmpose3d[animals]\""
) from None

superanimal_name = "superanimal_quadruped"
model_name = "hrnet_w32"
detector_name = "fasterrcnn_resnet50_fpn_v2"
max_individuals = 1

def compute_limb_regularization_matrix(gt_3d):
"""
Compute regularization matrix to align limb directions to vertical (0,0,1).
@@ -145,108 +130,39 @@ def apply_regularization(pose_3d, R):
"""
return (R @ pose_3d.T).T

def get_pose2D(path, output_dir, type):
def build_2d_estimator():
"""Build the 2D pose estimator once. Snapshot resolves lazily on first predict.

Empty --saved_2d_model_path -> auto-download fine-tuned snapshot from HF.
Non-empty path -> use as a local override.
"""
from fmpose3d.common.config import SuperAnimalConfig
from fmpose3d.inference_api.fmpose3d import SuperAnimalEstimator
from fmpose3d.utils.weights import resolve_weights_path

pose_snapshot_path = resolve_weights_path(
args.saved_2d_model_path, "sa_finetune_hrnet_w32.pt"
)
cfg = SuperAnimalConfig(
pose_snapshot_path=pose_snapshot_path,
pytorch_config_path=args.pytorch_config_2d_path,
)
print(f"[2D] pose snapshot = {cfg.pose_snapshot_path}")
return SuperAnimalEstimator(cfg)


def get_pose2D(estimator, path, output_dir, type):

print('\nGenerating 2D pose...')

# Check if this is the special debug case for 000000119761_horse
filename = Path(path).stem
is_debug_case = "000000119761_horse" in filename

if is_debug_case:
print(f"DEBUG MODE: Using provided 2D pose for {filename}")
# User provided 2D pose (26 keypoints, x, y coordinates, ignoring the last dimension)
provided_pose = np.array([
[361, 230], [361, 237], [363, 279], [257, 359], [251, 374],
[164, 365], [68, 372], [99, 206], [247, 266], [253, 285],
[127, 275], [101, 285], [267, 217], [268, 229], [273, 318],
[250, 340], [128, 311], [76, 305], [313, 220], [48, 310],
[351, 203], [352, 210], [340, 257], [340, 261], [373, 276],
[55, 247]
], dtype=np.float32)

# Reshape to match expected format: (1, 26, 2) for single individual
provided_pose = provided_pose.reshape(1, 26, 2)

# Create xy_preds dict with the provided pose
xy_preds = {path: provided_pose}
print(f"Using provided 2D pose with shape: {provided_pose.shape}")
else:
# Normal prediction flow
predictions = superanimal_analyze_images(
superanimal_name,
model_name,
detector_name,
path,
max_individuals,
out_folder=output_dir
)
print("predictions:", predictions)

# get the 2D keypoints from the predictions
xy_preds = {}
# predictions is a dict: {image_path: {"bodyparts": (N, K, 3), "bboxes": ..., "bbox_scores": ...}}
for img_path, payload in predictions.items():
bodyparts = payload.get("bodyparts")
if bodyparts is None:
continue
# bodyparts shape: (num_individuals, num_keypoints, 3) -> [:, :, :2] keeps x,y
xy_preds[img_path] = bodyparts[..., :2]

print("2D keypoints (x,y) by image:")
for img_path, xy in xy_preds.items():
print(f"{img_path}: shape {xy.shape}")

# For debug case, the provided pose is already in Animal3D format (26 keypoints)
# So we skip the mapping step
if is_debug_case:
print("DEBUG MODE: Skipping keypoint mapping (already in Animal3D format)")
mapped_keypoints = xy_preds
else:
# now map the keypoints to a different set of keypoints (used in Animal3D)
# keypoint mapping from quadruped80K super keypotints to animal3d keypoints
keypoint_mapping = {"quadruped80k":[10, 5, -1, 26, 29, 30, 35, 22, 24, 27, 31, 32, -1, -1, 25, 28, 33, 34, 15, 23, 11, 6, 4, 3, 0, -1]}

# for the keypoint_mapping, -1 indicates that there is no corresponding keypoint in the source set, but we can interpolate
# for index 2, we can interpolate between keypoints 3 and 4 in the source set to get a better estimate of the missing keypoint
# for index 25, we can interpolate between keypoints 22 and 23 in the source set
# for index 12, we can interpolate between keypoints 24 and 19 in the source set
# for index 13, we can interpolate between keypoints 27 and 19 in the source set

# Define interpolation rules for -1 indices: {target_idx: (source_idx1, source_idx2)}
interpolation_rules = {
2: (3, 4), # interpolate between source keypoints 3 and 4
12: (24, 19), # interpolate between source keypoints 24 and 19
13: (27, 19), # interpolate between source keypoints 27 and 19
25: (22, 23), # interpolate between source keypoints 22 and 23
}

# map the keypoints
mapped_keypoints = {}
mapping_indices = keypoint_mapping["quadruped80k"]

for img_path, xy in xy_preds.items():
# xy shape: (num_individuals, num_keypoints, 2)
num_individuals, num_keypoints, _ = xy.shape
num_target_keypoints = len(mapping_indices)

# Initialize mapped array with NaN or zeros
mapped_xy = np.full((num_individuals, num_target_keypoints, 2), np.nan)

for target_idx, source_idx in enumerate(mapping_indices):
if source_idx != -1 and source_idx < num_keypoints:
# Copy the keypoint from source to target position
mapped_xy[:, target_idx, :] = xy[:, source_idx, :]
elif source_idx == -1 and target_idx in interpolation_rules:
# Perform interpolation for -1 indices
src1, src2 = interpolation_rules[target_idx]
if src1 < num_keypoints and src2 < num_keypoints:
# Interpolate as the average of the two source keypoints
mapped_xy[:, target_idx, :] = (xy[:, src1, :] + xy[:, src2, :]) / 2.0
print(f"Interpolated keypoint {target_idx} from source keypoints {src1} and {src2}")

mapped_keypoints[img_path] = mapped_xy
print(f"Mapped {img_path}: {xy.shape} -> {mapped_xy.shape}")

img_bgr = cv2.imread(path)
if img_bgr is None:
raise FileNotFoundError(f"Failed to read image: {path}")

# predict() returns (kpts (1, N, 26, 2), scores (1, N, 26), valid_mask (N,)).
kpts, _scores, _mask = estimator.predict(img_bgr[None])
# Pack into the {img_path: (1, 26, 2)} format expected by the save/vis code below.
mapped_keypoints = {path: kpts[:, 0, :, :]}

print('Generating 2D pose successful!')

@@ -259,7 +175,6 @@ def get_pose2D(path, output_dir, type):
# Save in the same format as vis_in_the_wild.py for compatibility
output_npz = output_dir_2D + 'keypoints.npz'
np.savez_compressed(output_npz, reconstruction=mapped_xy)
print(f"Saved keypoints to {output_npz}")

# Also save as npy for backup
img_name = Path(img_path).stem
@@ -275,7 +190,6 @@ def get_pose2D(path, output_dir, type):
index=[f'keypoint_{i}' for i in range(mapped_xy.shape[1])]
)
df.to_csv(csv_file)
print(f"Saved individual {ind_idx} keypoints to {csv_file}")

# Visualize mapped keypoints on image
img = Image.open(img_path)
@@ -328,39 +242,38 @@ def get_pose2D(path, output_dir, type):
plt.tight_layout()
plt.savefig(vis_file, dpi=150, bbox_inches='tight')
plt.close(fig)
print(f"Saved visualization to {vis_file}")


def get_pose3D(path, output_dir, type='image'):
"""
Generate 3D pose from 2D keypoints using the model.
This function reads the 2D keypoints saved by get_pose2D and generates 3D poses.
def build_3d_lifter():
"""Build the 3D lifter once and return it in eval mode.

Empty --saved_model_path -> auto-download fmpose3d_animals.pth from HF.
Non-empty path is used as a local override.
"""
print('\nGenerating 3D pose...')
print(f"args.n_joints: {args.n_joints}, args.out_joints: {args.out_joints}")

## Reload model
from fmpose3d.utils.weights import resolve_weights_path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CFM(args).to(device)

model = {}
model['CFM'] = CFM(args).to(device)

model_dict = model['CFM'].state_dict()
model_path = args.saved_model_path
print(f"Loading model from: {model_path}")
model_path = resolve_weights_path(args.saved_model_path, f"{args.model_type}.pth")
print(f"[3D] lifter weights = {model_path}")
pre_dict = torch.load(model_path, map_location=device, weights_only=True)
for name, key in model_dict.items():
model_dict = model.state_dict()
for name in model_dict:
model_dict[name] = pre_dict[name]
model['CFM'].load_state_dict(model_dict)
print("Model loaded successfully!")

model = model['CFM'].eval()
model.load_state_dict(model_dict)
return model.eval()


def get_pose3D(model, path, output_dir, type='image'):
"""
Generate 3D pose from 2D keypoints using the model.
Reads the 2D keypoints saved by get_pose2D and generates 3D poses.
"""
print('\nGenerating 3D pose...')

## Load input 2D keypoints
keypoints = np.load(output_dir + 'input_2D/keypoints.npz', allow_pickle=True)['reconstruction']
print(f"Loaded keypoints shape: {keypoints.shape}")

## Generate 3D poses
if type == "image":
i = 0
img = cv2.imread(path)
@@ -422,9 +335,6 @@ def euler_sample(c_2d, y_local, steps, model_3d):
return y_local

## Estimation (without TTA for better results)
print("input_2D.shape:", input_2D.shape)
print("input_2D:", input_2D[0, 0])

# Single inference without flip augmentation
# Create 3D random noise with shape (1, 1, J, 3)
y = torch.randn(input_2D.size(0), input_2D.size(1), input_2D.size(2), 3, device=device)
@@ -492,7 +402,6 @@ def euler_sample(c_2d, y_local, steps, model_3d):
output_dir_2D_img = output_dir + 'pose2D_on_image/'
os.makedirs(output_dir_2D_img, exist_ok=True)
cv2.imwrite(f'{output_dir_2D_img}{i:04d}_2d.png', img_copy)
print(f"Saved 2D pose on image to {output_dir_2D_img}{i:04d}_2d.png")

## Save 3D pose as npz
output_dir_3D = output_dir + 'pose3D/'
@@ -603,46 +512,46 @@ def img2gif(video_path, name, output_dir, duration=0.25):


if __name__ == "__main__":

os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

path = args.path # file path or folder path

# Check if path is a directory

# Build the 2D estimator and 3D lifter ONCE; reuse across all images/frames.
# This avoids redundant HF resolution and DLC/torch model reloads.
estimator_2d = build_2d_estimator()
model_3d = build_3d_lifter()

if os.path.isdir(path):
# Get all image files in the directory
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.JPG', '*.JPEG', '*.PNG', '*.BMP']
image_files = []
for ext in image_extensions:
image_files.extend(glob.glob(os.path.join(path, ext)))
image_files.sort()

if len(image_files) == 0:
print(f"No image files found in {path}")
exit(0)

print(f"Found {len(image_files)} images in {path}")

# Process each image

for img_path in tqdm(image_files, desc="Processing images"):
filename = img_path.split('/')[-1].split('.')[0]
output_dir = './predictions/' + filename + '/'

print(f"\nProcessing: {img_path}")
get_pose2D(img_path, output_dir, args.type)
get_pose3D(img_path, output_dir, args.type)
get_pose2D(estimator_2d, img_path, output_dir, args.type)
get_pose3D(model_3d, img_path, output_dir, args.type)

print(f'\nAll {len(image_files)} images processed successfully!')
else:
# Single file processing
filename = path.split('/')[-1].split('.')[0]
output_dir = './predictions/' + filename + '/'

get_pose2D(path, output_dir, args.type)
get_pose3D(path, output_dir, args.type)
get_pose2D(estimator_2d, path, output_dir, args.type)
get_pose3D(model_3d, path, output_dir, args.type)

if args.type=="video":
if args.type == "video":
img2video(path, filename, output_dir)
img2gif(path, filename, output_dir)

print('Generating demo successful!')
img2gif(path, filename, output_dir)
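
The `__main__` block above collects images by extension and names each output folder after the image file. A minimal sketch of the same idea with `pathlib` follows. The helper names `collect_images` and `prediction_dir` are made up for this example and do not exist in the repository. Two behaviors differ from the script and are assumptions of the sketch: suffixes are matched case-insensitively (so `.Jpg` is also picked up), and `Path.stem` keeps inner dots (`a.b.jpg` gives `a.b`, where `split('.')[0]` would give `a`).

```python
from pathlib import Path

# Same extensions as the glob patterns in the script, compared in lower case.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def collect_images(folder):
    """Return sorted image paths directly inside `folder` (no recursion)."""
    return sorted(
        p for p in Path(folder).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )


def prediction_dir(image_path, root="./predictions"):
    """Build the per-image output folder from the file stem."""
    return Path(root) / Path(image_path).stem
```

Using `Path` also avoids splitting on a hard-coded `/`, which may matter if the demo is run on a platform with a different path separator.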